def testCloneModelWithDefinitionOnly(self):
    """ Cloning a model that only has a definition copies the definition
    but yields no loadable checkpoint
    """
    mgr = ModelCheckpointMgr()

    srcID = uuid.uuid1().hex
    dstID = uuid.uuid1().hex

    # The source model gets a definition but no checkpoint
    definition = {'a': 1, 'b': 2, 'c': 3}
    mgr.define(srcID, definition)

    mgr.clone(srcID, dstID)

    # The clone must carry an identical definition
    clonedDefinition = mgr.loadModelDefinition(dstID)
    self.assertDictEqual(clonedDefinition, definition)

    # There is no checkpoint to load for the clone
    self.assertRaises(ModelNotFound, mgr.load, dstID)

    # Cloning onto an already-existing destination must be rejected
    self.assertRaises(ModelAlreadyExists, mgr.clone, srcID, dstID)
    def testCloneModelWithDefinitionOnly(self):
        """ Cloning a model that only has a definition copies the definition
        but yields no loadable checkpoint
        """
        mgr = ModelCheckpointMgr()

        srcID = uuid.uuid1().hex
        dstID = uuid.uuid1().hex

        # The source model gets a definition but no checkpoint
        definition = {'a': 1, 'b': 2, 'c': 3}
        mgr.define(srcID, definition)

        mgr.clone(srcID, dstID)

        # The clone must carry an identical definition
        clonedDefinition = mgr.loadModelDefinition(dstID)
        self.assertDictEqual(clonedDefinition, definition)

        # There is no checkpoint to load for the clone
        self.assertRaises(ModelNotFound, mgr.load, dstID)

        # Cloning onto an already-existing destination must be rejected
        self.assertRaises(ModelAlreadyExists, mgr.clone, srcID, dstID)
  def testCloneModelWithCheckpoint(self):
    """ Cloning a checkpointed model copies its definition, checkpoint and
    checkpoint attributes; the clone survives removal of the source
    """
    mgr = ModelCheckpointMgr()

    srcID = uuid.uuid1().hex
    dstID = uuid.uuid1().hex

    # Register the source model's definition first
    definition = {'a': 1, 'b': 2, 'c': 3}
    mgr.define(srcID, definition)

    # Then give the source a checkpoint worth cloning
    srcModel = ModelFactory.create(self._getModelParams("variant1"))
    mgr.save(srcID, srcModel, attributes="attributes1")

    mgr.clone(srcID, dstID)

    # Removing the source must not affect the clone
    mgr.remove(srcID)

    clonedDefinition = mgr.loadModelDefinition(dstID)
    self.assertDictEqual(clonedDefinition, definition)

    clonedAttributes = mgr.loadCheckpointAttributes(dstID)
    self.assertEqual(clonedAttributes, "attributes1")

    # The cloned checkpoint must be loadable and match the source model
    clonedModel = mgr.load(dstID)
    self.assertEqual(str(clonedModel.getFieldInfo()),
                     str(srcModel.getFieldInfo()))
    def testCloneModelWithCheckpoint(self):
        """ Cloning a checkpointed model copies its definition, checkpoint and
        checkpoint attributes; the clone survives removal of the source
        """
        mgr = ModelCheckpointMgr()

        srcID = uuid.uuid1().hex
        dstID = uuid.uuid1().hex

        # Register the source model's definition first
        definition = {'a': 1, 'b': 2, 'c': 3}
        mgr.define(srcID, definition)

        # Then give the source a checkpoint worth cloning
        srcModel = ModelFactory.create(self._getModelParams("variant1"))
        mgr.save(srcID, srcModel, attributes="attributes1")

        mgr.clone(srcID, dstID)

        # Removing the source must not affect the clone
        mgr.remove(srcID)

        clonedDefinition = mgr.loadModelDefinition(dstID)
        self.assertDictEqual(clonedDefinition, definition)

        clonedAttributes = mgr.loadCheckpointAttributes(dstID)
        self.assertEqual(clonedAttributes, "attributes1")

        # The cloned checkpoint must be loadable and match the source model
        clonedModel = mgr.load(dstID)
        self.assertEqual(str(clonedModel.getFieldInfo()),
                         str(srcModel.getFieldInfo()))
Ejemplo n.º 5
0
class _ModelArchiver(object):
    """ Helper class for loading/creating and checkpointing model
  """

    # Name of the attribute that is stored as an integral component of the
    # checkpoint. This attribute contains a list of batchIDs that were processed
    # since the previous checkpoint.
    _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs"

    # Name of the attribute that is stored as an integral component of the
    # checkpoint. It contains a list of ModelInputRow objects processed since
    # the last full checkpoint for use in preparing the last incremental
    # model checkpoint for new input. The value is in pickle string format.
    _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples"

    _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100

    def __init__(self, modelID):
        """
    :param modelID: model ID; string
    """
        self._modelID = modelID

        # The model object from OPF ModelFactory; set up by the loadModel() method
        self._model = None

        # The _InputRowEncoder object; set up by the loadModel() method
        self._inputRowEncoder = None

        self._checkpointMgr = ModelCheckpointMgr()

        # True if model was loaded from existing checkpoint or after a full
        # checkpoint is made; False initially if model was created from
        # params. This informs us whether an incremental checkpoint is possible.
        self._hasCheckpoint = False

        self._modelCheckpointBatchIDSetCache = None

        # Input data samples that have accumulated since last full checkpoint
        self._inputSamplesSinceLastFullCheckpointCache = None

    @property
    def model(self):
        """ An OPF Model object or None if not loaded yet """
        return self._model

    @property
    def inputRowEncoder(self):
        """ An _InputRowEncoder object or None if model not loaded yet """
        return self._inputRowEncoder

    @property
    def modelCheckpointBatchIDSet(self):
        """ A sequence of input batch identifiers associated with current
    model checkpoint
    """
        if self._modelCheckpointBatchIDSetCache is None:
            self._loadCheckpointAttributes()
        return self._modelCheckpointBatchIDSetCache

    @property
    def checkpointMgr(self):
        return self._checkpointMgr

    @property
    def _inputSamplesSinceLastFullCheckpoint(self):
        if self._inputSamplesSinceLastFullCheckpointCache is None:
            self._loadCheckpointAttributes()
        return self._inputSamplesSinceLastFullCheckpointCache

    @_inputSamplesSinceLastFullCheckpoint.setter
    def _inputSamplesSinceLastFullCheckpoint(self, value):
        self._inputSamplesSinceLastFullCheckpointCache = value

    @classmethod
    def _encodeDataSamples(cls, dataSamples):
        """
    :param dataSamples: a sequence of data samples to be saved as the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a string encoding of the data samples
    """
        return base64.standard_b64encode(
            pickle.dumps(dataSamples, pickle.HIGHEST_PROTOCOL))

    @classmethod
    def _decodeDataSamples(cls, dataSamples):
        """
    :param dataSamples: string-encoded data samples from the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a sequence of data samples
    """
        return pickle.loads(base64.standard_b64decode(dataSamples))

    def _loadCheckpointAttributes(self):
        # Load the checkpoint attributes
        try:
            checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes(
                self._modelID)
        except model_checkpoint_mgr.ModelNotFound:
            self._modelCheckpointBatchIDSetCache = set()
            self._inputSamplesSinceLastFullCheckpoint = []
        else:
            self._modelCheckpointBatchIDSetCache = set(
                checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME])

            inputSamples = checkpointAttributes.get(
                self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME)
            if inputSamples:
                self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples(
                    inputSamples)
            else:
                self._inputSamplesSinceLastFullCheckpoint = []

    def loadModel(self):
        """ Load the model and construct the input row encoder. On success,
    the loaded model may be accessed via the `model` attribute

    :raises: model_checkpoint_mgr.ModelNotFound
    """
        if self._model is not None:
            return

        modelDefinition = None

        # Load the model
        try:
            self._model = self._checkpointMgr.load(self._modelID)
            self._hasCheckpoint = True
        except model_checkpoint_mgr.ModelNotFound:
            # So, we didn't have a checkpoint... try to create our model from model
            # definition params
            self._hasCheckpoint = False
            try:
                modelDefinition = self._checkpointMgr.loadModelDefinition(
                    self._modelID)
            except model_checkpoint_mgr.ModelNotFound:
                raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                        msg="modelID=%s not found" %
                                        (self._modelID))
            else:
                modelParams = modelDefinition["modelParams"]

            # TODO: when creating the model from params, do we need to call
            #   its model.setFieldStatistics() method? And where will the
            #   fieldStats come from, anyway?
            self._model = ModelFactory.create(
                modelConfig=modelParams["modelConfig"])
            self._model.enableLearning()
            self._model.enableInference(modelParams["inferenceArgs"])

        # Construct the object for converting a flat input row into a format
        # that is consumable by an OPF model
        try:
            if modelDefinition is None:
                modelDefinition = self._checkpointMgr.loadModelDefinition(
                    self._modelID)
        except model_checkpoint_mgr.ModelNotFound:
            raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                    msg="modelID=%s not found" %
                                    (self._modelID))
        else:
            inputSchema = modelDefinition["inputSchema"]

        # Convert it to a sequence of FieldMetaInfo instances
        # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be
        #   present; it would be a logic error if it isn't.
        inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema)

        self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta)

        # If the checkpoint was incremental, feed the cached data into the model
        for inputSample in self._inputSamplesSinceLastFullCheckpoint:
            # Convert a flat input sample into a format that is consumable by an OPF
            # model
            self._inputRowEncoder.appendRecord(inputSample)

            # Infer
            self._model.run(self._inputRowEncoder.getNextRecordDict())

    def saveModel(self, currentRunBatchIDSet, currentRunInputSamples):
        """
    :param currentRunBatchIDSet: a set of batch ids to be saved in model
      checkpoint attributes

    :param currentRunInputSamples: a sequence of model input data sample objects
      for incremental checkpoint; will be saved in checkpoint attributes if an
      incremental checkpoint is performed.
    """
        if self._model is not None:
            self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy()

            if (not self._hasCheckpoint
                    or (len(self._inputSamplesSinceLastFullCheckpoint) +
                        len(currentRunInputSamples)) >
                    self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS):
                # Perform a full checkpoint
                self._inputSamplesSinceLastFullCheckpointCache = []

                self._checkpointMgr.save(
                    modelID=self._modelID,
                    model=self._model,
                    attributes={
                        self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
                        list(self._modelCheckpointBatchIDSetCache)
                    })

                self._hasCheckpoint = True
            else:
                # Perform an incremental checkpoint
                self._inputSamplesSinceLastFullCheckpoint.extend(
                    currentRunInputSamples)
                attributes = {
                    self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
                    list(self._modelCheckpointBatchIDSetCache),
                    self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME:
                    self._encodeDataSamples(
                        self._inputSamplesSinceLastFullCheckpoint)
                }

                self._checkpointMgr.updateCheckpointAttributes(
                    self._modelID, attributes)
  def testModelCheckpointSaveAndLoadSupport(self):
    """ Test saving, loading and replacing model checkpoints """
    mgr = ModelCheckpointMgr()

    modelID = uuid.uuid1().hex

    # Nothing has been stored for this ID yet
    self.assertRaises(ModelNotFound, mgr.load, modelID)

    # Build a model to checkpoint
    firstModel = ModelFactory.create(self._getModelParams("variant1"))

    # A definition on its own is not a loadable checkpoint
    mgr.define(modelID, definition=dict(a=1, b=2))
    self.assertRaises(ModelNotFound, mgr.load, modelID)

    # Checkpoint the model and read it back
    mgr.save(modelID, firstModel, attributes="attributes1")
    restored = mgr.load(modelID)
    self.assertEqual(str(restored.getFieldInfo()),
                     str(firstModel.getFieldInfo()))
    del restored
    del firstModel

    self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes1")

    # Saving again under the same ID replaces the existing checkpoint
    secondModel = ModelFactory.create(self._getModelParams("variant2"))
    mgr.save(modelID, secondModel, attributes="attributes2")
    restored = mgr.load(modelID)
    self.assertEqual(str(restored.getFieldInfo()),
                     str(secondModel.getFieldInfo()))

    self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes2")

    thirdModel = ModelFactory.create(self._getModelParams("variant3"))
    mgr.save(modelID, thirdModel, attributes="attributes3")
    restored = mgr.load(modelID)
    self.assertEqual(str(restored.getFieldInfo()),
                     str(thirdModel.getFieldInfo()))

    self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes3")

    # A save that fails (invalid model object) must leave the previous
    # checkpoint intact
    try:
      mgr.save(modelID, "InvalidModel", attributes="attributes4")
    except AttributeError:
      pass
    restored = mgr.load(modelID)
    self.assertEqual(str(restored.getFieldInfo()),
                     str(thirdModel.getFieldInfo()))

    self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes3")
    def testModelCheckpointSaveAndLoadSupport(self):
        """ Test saving, loading and replacing model checkpoints """
        mgr = ModelCheckpointMgr()

        modelID = uuid.uuid1().hex

        # Nothing has been stored for this ID yet
        self.assertRaises(ModelNotFound, mgr.load, modelID)

        # Build a model to checkpoint
        firstModel = ModelFactory.create(self._getModelParams("variant1"))

        # A definition on its own is not a loadable checkpoint
        mgr.define(modelID, definition=dict(a=1, b=2))
        self.assertRaises(ModelNotFound, mgr.load, modelID)

        # Checkpoint the model and read it back
        mgr.save(modelID, firstModel, attributes="attributes1")
        restored = mgr.load(modelID)
        self.assertEqual(str(restored.getFieldInfo()),
                         str(firstModel.getFieldInfo()))
        del restored
        del firstModel

        self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes1")

        # Saving again under the same ID replaces the existing checkpoint
        secondModel = ModelFactory.create(self._getModelParams("variant2"))
        mgr.save(modelID, secondModel, attributes="attributes2")
        restored = mgr.load(modelID)
        self.assertEqual(str(restored.getFieldInfo()),
                         str(secondModel.getFieldInfo()))

        self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes2")

        thirdModel = ModelFactory.create(self._getModelParams("variant3"))
        mgr.save(modelID, thirdModel, attributes="attributes3")
        restored = mgr.load(modelID)
        self.assertEqual(str(restored.getFieldInfo()),
                         str(thirdModel.getFieldInfo()))

        self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes3")

        # A save that fails (invalid model object) must leave the previous
        # checkpoint intact
        try:
            mgr.save(modelID, "InvalidModel", attributes="attributes4")
        except AttributeError:
            pass
        restored = mgr.load(modelID)
        self.assertEqual(str(restored.getFieldInfo()),
                         str(thirdModel.getFieldInfo()))

        self.assertEqual(mgr.loadCheckpointAttributes(modelID), "attributes3")
Ejemplo n.º 8
0
class _ModelArchiver(object):
  """ Helper class for loading/creating and checkpointing model
  """

  # Name of the attribute that is stored as an integral component of the
  # checkpoint. This attribute contains a list of batchIDs that were processed
  # since the previous checkpoint.
  _BATCH_IDS_CHECKPOINT_ATTR_NAME = "batchIDs"

  # Name of the attribute that is stored as an integral component of the
  # checkpoint. It contains a list of ModelInputRow objects processed since
  # the last full checkpoint for use in preparing the last incremental
  # model checkpoint for new input. The value is in pickle string format.
  _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME = "incrementalInputSamples"

  _MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS = 100


  def __init__(self, modelID):
    """
    :param modelID: model ID; string
    """
    self._modelID = modelID

    # The model object from OPF ModelFactory; set up by the loadModel() method
    self._model = None

    # The _InputRowEncoder object; set up by the loadModel() method
    self._inputRowEncoder = None

    self._checkpointMgr = ModelCheckpointMgr()

    # True if model was loaded from existing checkpoint or after a full
    # checkpoint is made; False initially if model was created from
    # params. This informs us whether an incremental checkpoint is possible.
    self._hasCheckpoint = False

    self._modelCheckpointBatchIDSetCache = None

    # Input data samples that have accumulated since last full checkpoint
    self._inputSamplesSinceLastFullCheckpointCache = None


  @property
  def model(self):
    """ An OPF Model object or None if not loaded yet """
    return self._model


  @property
  def inputRowEncoder(self):
    """ An _InputRowEncoder object or None if model not loaded yet """
    return self._inputRowEncoder


  @property
  def modelCheckpointBatchIDSet(self):
    """ A sequence of input batch identifiers associated with current
    model checkpoint
    """
    if self._modelCheckpointBatchIDSetCache is None:
      self._loadCheckpointAttributes()
    return self._modelCheckpointBatchIDSetCache


  @property
  def checkpointMgr(self):
    return self._checkpointMgr


  @property
  def _inputSamplesSinceLastFullCheckpoint(self):
    if self._inputSamplesSinceLastFullCheckpointCache is None:
      self._loadCheckpointAttributes()
    return self._inputSamplesSinceLastFullCheckpointCache


  @_inputSamplesSinceLastFullCheckpoint.setter
  def _inputSamplesSinceLastFullCheckpoint(self, value):
    self._inputSamplesSinceLastFullCheckpointCache = value


  @classmethod
  def _encodeDataSamples(cls, dataSamples):
    """
    :param dataSamples: a sequence of data samples to be saved as the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a string encoding of the data samples
    """
    return base64.standard_b64encode(pickle.dumps(dataSamples,
                                                  pickle.HIGHEST_PROTOCOL))


  @classmethod
  def _decodeDataSamples(cls, dataSamples):
    """
    :param dataSamples: string-encoded data samples from the
      _INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME checkpoint attribute

    :returns: a sequence of data samples
    """
    return pickle.loads(base64.standard_b64decode(dataSamples))


  def _loadCheckpointAttributes(self):
    # Load the checkpoint attributes
    try:
      checkpointAttributes = self._checkpointMgr.loadCheckpointAttributes(
        self._modelID)
    except model_checkpoint_mgr.ModelNotFound:
      self._modelCheckpointBatchIDSetCache = set()
      self._inputSamplesSinceLastFullCheckpoint = []
    else:
      self._modelCheckpointBatchIDSetCache = set(
        checkpointAttributes[self._BATCH_IDS_CHECKPOINT_ATTR_NAME])

      inputSamples = checkpointAttributes.get(
        self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME)
      if inputSamples:
        self._inputSamplesSinceLastFullCheckpoint = self._decodeDataSamples(
          inputSamples)
      else:
        self._inputSamplesSinceLastFullCheckpoint = []


  def loadModel(self):
    """ Load the model and construct the input row encoder. On success,
    the loaded model may be accessed via the `model` attribute

    :raises: model_checkpoint_mgr.ModelNotFound
    """
    if self._model is not None:
      return

    modelDefinition = None

    # Load the model
    try:
      self._model = self._checkpointMgr.load(self._modelID)
      self._hasCheckpoint = True
    except model_checkpoint_mgr.ModelNotFound:
      # So, we didn't have a checkpoint... try to create our model from model
      # definition params
      self._hasCheckpoint = False
      try:
        modelDefinition = self._checkpointMgr.loadModelDefinition(self._modelID)
      except model_checkpoint_mgr.ModelNotFound:
        raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                                msg="modelID=%s not found" % (self._modelID))
      else:
        modelParams = modelDefinition["modelParams"]

      # TODO: when creating the model from params, do we need to call
      #   its model.setFieldStatistics() method? And where will the
      #   fieldStats come from, anyway?
      self._model = ModelFactory.create(
        modelConfig=modelParams["modelConfig"])
      self._model.enableLearning()
      self._model.enableInference(modelParams["inferenceArgs"])


    # Construct the object for converting a flat input row into a format
    # that is consumable by an OPF model
    try:
      if modelDefinition is None:
        modelDefinition = self._checkpointMgr.loadModelDefinition(
          self._modelID)
    except model_checkpoint_mgr.ModelNotFound:
      raise _ModelRunnerError(errno=htmengineerrno.ERR_NO_SUCH_MODEL,
                              msg="modelID=%s not found" % (self._modelID))
    else:
      inputSchema = modelDefinition["inputSchema"]

    # Convert it to a sequence of FieldMetaInfo instances
    # NOTE: if loadMetaInfo didn't raise, we expect "inputSchema" to be
    #   present; it would be a logic error if it isn't.
    inputFieldsMeta = tuple(FieldMetaInfo(*f) for f in inputSchema)

    self._inputRowEncoder = _InputRowEncoder(fieldsMeta=inputFieldsMeta)

    # If the checkpoint was incremental, feed the cached data into the model
    for inputSample in self._inputSamplesSinceLastFullCheckpoint:
      # Convert a flat input sample into a format that is consumable by an OPF
      # model
      self._inputRowEncoder.appendRecord(inputSample)

      # Infer
      self._model.run(self._inputRowEncoder.getNextRecordDict())


  def saveModel(self, currentRunBatchIDSet, currentRunInputSamples):
    """
    :param currentRunBatchIDSet: a set of batch ids to be saved in model
      checkpoint attributes

    :param currentRunInputSamples: a sequence of model input data sample objects
      for incremental checkpoint; will be saved in checkpoint attributes if an
      incremental checkpoint is performed.
    """
    if self._model is not None:
      self._modelCheckpointBatchIDSetCache = currentRunBatchIDSet.copy()

      if (not self._hasCheckpoint or
          (len(self._inputSamplesSinceLastFullCheckpoint) +
           len(currentRunInputSamples)) >
          self._MAX_INCREMENTAL_CHECKPOINT_DATA_ROWS):
        # Perform a full checkpoint
        self._inputSamplesSinceLastFullCheckpointCache = []

        self._checkpointMgr.save(
          modelID=self._modelID, model=self._model,
          attributes={
            self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
              list(self._modelCheckpointBatchIDSetCache)})

        self._hasCheckpoint = True
      else:
        # Perform an incremental checkpoint
        self._inputSamplesSinceLastFullCheckpoint.extend(currentRunInputSamples)
        attributes = {
          self._BATCH_IDS_CHECKPOINT_ATTR_NAME:
            list(self._modelCheckpointBatchIDSetCache),

          self._INPUT_SAMPLES_SINCE_CHECKPOINT_ATTR_NAME:
            self._encodeDataSamples(self._inputSamplesSinceLastFullCheckpoint)
        }

        self._checkpointMgr.updateCheckpointAttributes(self._modelID,
                                                       attributes)