Example #1
 def save_checkpoint(self):
   """Saves checkpoint to disk for the current training step."""
   if not self.is_chief:
     _log('Did not save checkpoint as we are not chief.')
     return
   if self._output_dir is None:
     _log('Did not save checkpoint as output_dir is None')
     return
   weights = self._model.weights
   state = self._model.state
   if self._use_memory_efficient_trainer:
     slots_per_task = [trainer.slots for trainer in self._trainer_per_task]
   else:
     slots_per_task = tuple(task.optimizer.slots for task in self._tasks)
   # We only need the input signature for the body, not for the loss layers.
   # That part is the same across tasks - take it from the first one.
   input_signature = self._batch_signature[:self._model.n_in]
   flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
   _, flat_eval_state = tl.flatten_weights_and_state(
       weights, self._eval_model.state)
   d = {
       'step': self.step,
       'flat_weights': flat_weights,
       'flat_state': flat_state,
       'flat_eval_state': flat_eval_state,
       'slots_per_task': slots_per_task,
       'input_signature': input_signature,
       'version_timestamp': 'Sep-17-2020'  # To update in the future if needed.
   }
   ckpt_file = os.path.join(self._output_dir, 'model.pkl.gz')
   pickle_to_file(d, ckpt_file, gzip=True)
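
The pickle_to_file helper is not shown in this example. A minimal sketch of the loading side, assuming gzip=True means a standard gzip-compressed pickle (load_checkpoint and the path are illustrative names, not part of the code above):

import gzip
import pickle

def load_checkpoint(ckpt_file):
  # Inverse of pickle_to_file(d, ckpt_file, gzip=True):
  # decompress, then unpickle the checkpoint dict.
  with gzip.open(ckpt_file, 'rb') as f:
    return pickle.load(f)

d = load_checkpoint('/tmp/output_dir/model.pkl.gz')
print(d['step'], d['version_timestamp'])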
Example #2
 def save_checkpoint(self, permanent=False):
     """Saves checkpoint to disk for the current training step."""
     if not self.is_chief:
         _log('Did not save checkpoint as we are not chief.')
         return
     if self._output_dir is None:
         _log('Did not save checkpoint as output_dir is None')
         return
     if permanent:
         filename = 'model_{}.pkl.gz'.format(self.step)
     else:
         filename = 'model.pkl.gz'
     ckpt_file = os.path.join(self._output_dir, filename)
     _log('Saving checkpoint to %s.' % ckpt_file, stdout=False)
     weights = self._model.weights
     state = self._model.state
     if self._use_memory_efficient_trainer:
         slots_per_task = [
             trainer.slots for trainer in self._trainer_per_task
         ]
     else:
         slots_per_task = tuple(task.optimizer.slots
                                for task in self._tasks)
     # We only need the input signature for the body, not for the loss layers.
     # That part is the same across tasks - take it from the first one.
     input_signature = self._batch_signature[:self._model.n_in]
     flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
     _, flat_eval_state = tl.flatten_weights_and_state(
         weights, self._eval_model.state)
     if self._use_memory_efficient_trainer:
         sharded_weights_len = self._save_weights_sharded(
             flat_weights, ckpt_file)
         # In the main dict we just save the number of shards in place of weights.
         weights_in_dict = sharded_weights_len
     else:
         weights_in_dict = self._to_bits(flat_weights)
     d = {
         'step': self.step,
         'flat_weights': weights_in_dict,
         'flat_state': flat_state,
         'flat_eval_state': flat_eval_state,
         'slots_per_task': slots_per_task,
         'input_signature': input_signature,
          'version_timestamp': 'Oct-28-2020'  # To update in the future if needed.
     }
     pickle_to_file(d, ckpt_file, gzip=True)
     # Move sharded files to non-tmp files after all is saved.
     if self._use_memory_efficient_trainer:
         for i in range(weights_in_dict):
             fname = ckpt_file + '.shard%d' % i
             tf.io.gfile.rename(fname + '.tmp', fname, overwrite=True)
     _log('Checkpoint saved in %s.' % ckpt_file, stdout=False)
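
Note the write order here: the shard files are first written under a .tmp suffix by _save_weights_sharded, and only renamed once the main dict has been pickled, so a partially written checkpoint never points at missing shards. A generic sketch of that write-then-rename pattern (write_shard is a hypothetical helper, not part of the code above):

import tensorflow as tf

def write_shard(path, payload):
  # Write the bytes to a temporary file first, then rename it into
  # place, so readers never observe a half-written shard.
  tmp_path = path + '.tmp'
  with tf.io.gfile.GFile(tmp_path, 'wb') as f:
    f.write(payload)
  tf.io.gfile.rename(tmp_path, path, overwrite=True)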
Example #3
def make_trainer_state_dict(step, opt_state, history, model_state,
                            input_signature):
    """Creates a trainer state dictionary to save to disk.

  Args:
    step: int, a step number
    opt_state: OptState namedtuple
    history: `trax.history.History`, the history object.
    model_state: A nested structure of the model state.
    input_signature: signature of model inputs.

  Returns:
    A dictionary with the fields of TrainerState and OptState flattened.
  """
    flat_weights, flat_state = tl.flatten_weights_and_state(
        opt_state.weights, model_state)
    return {
        'step': step,
        'flat_weights': flat_weights,
        'slots': opt_state.slots,
        'opt_params': opt_state.opt_params,
        'history': history,
        'flat_state': flat_state,
        'input_signature': input_signature,
        'version_timestamp': 'Jun-18-2020'  # To update in the future if needed.
    }
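
A sketch of the inverse direction, rebuilding the nested weights and state from such a dict; it assumes a model whose structure matches the saved signature, and uses weights_and_state_signature the same way as the round-trip test in Example #5 (restore_from_trainer_state_dict is an illustrative name):

from trax import layers as tl

def restore_from_trainer_state_dict(model, d):
  # Recover the nested weights/state structures from the flat lists,
  # keyed by the model's own weights-and-state signature.
  sig = model.weights_and_state_signature(d['input_signature'])
  weights, state = tl.unflatten_weights_and_state(
      d['flat_weights'], d['flat_state'], sig)
  model.weights = weights
  model.state = state
  return d['step'], d['slots'], d['history']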
Example #4
  def save_checkpoint(self, weights=None, state=None, slots=None):
    """Saves checkpoint to disk for the current training step.

    Args:
      weights: Weights from model being trained.
      state: State (non-weight parameters) from model being trained.
      slots: Updatable weights for the optimizer in this training loop.
    """
    if not self.is_chief:
      return
    if self._output_dir is None:
      _log('Did not save checkpoint as output_dir is None', stdout=False)
      return
    weights = self._model_in_training.weights if weights is None else weights
    state = self._model_in_training.state if state is None else state
    slots = self._task.optimizer.slots if slots is None else slots
    flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
    d = {
        'step': self.step,
        'flat_weights': flat_weights,
        'flat_state': flat_state,
        'slots': slots,
        'input_signature': self._batch_signature,
        'version_timestamp': 'Jun-29-2020'  # To update in the future if needed.
    }
    ckpt_file = os.path.join(self._output_dir, 'model.pkl.gz')
    pickle_to_file(d, ckpt_file, gzip=True)
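
The keyword arguments let a caller checkpoint values that live elsewhere during training rather than on the model object; with no arguments the method falls back to the model's and optimizer's current values. A hypothetical call site (loop is an illustrative name for the training loop object):

# Snapshot whatever the model and optimizer currently hold.
loop.save_checkpoint()

# Or checkpoint explicit values fetched from elsewhere.
loop.save_checkpoint(weights=new_weights, state=new_state, slots=new_slots)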
Example #5
 def test_flat_weights_and_state(self):
   model = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup()))
   sample_input_signature = shapes.signature(np.zeros((2, 3)))
   model.init(sample_input_signature)
   flat_weights, flat_state = tl.flatten_weights_and_state(
       model.weights, model.state)
   # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers.
   # So after making them flat, there are 4 trainable weights.
   self.assertLen(flat_weights, 4)
   self.assertEmpty(flat_state)
   model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7))
   sig = model2.weights_and_state_signature(sample_input_signature)
   weights2, state2 = tl.unflatten_weights_and_state(
       flat_weights, flat_state, sig)
   model2.weights = weights2
   model2.state = state2
   self.assertLen(model2.weights, 3)
   self.assertEqual(model.weights[1], model2.weights[0])
   self.assertEqual(model.weights[2][0], model2.weights[2])
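
The round trip through the differently shaped model2 works because parameterless layers such as Dup contribute nothing to the flat list, so both models flatten to the same four Dense arrays. A minimal check of that assumption:

import numpy as np
from trax import layers as tl
from trax import shapes

dup = tl.Dup()
dup.init(shapes.signature(np.zeros((2, 3))))
flat_w, flat_s = tl.flatten_weights_and_state(dup.weights, dup.state)
print(flat_w, flat_s)  # Expected: both empty - Dup has no weights or state.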