Esempio n. 1
0
  def load_checkpoint(self, directory=None, filename=None):
    """Restores step, optimizer slots, weights and state from a checkpoint.

    Returns without loading anything (after logging the reason) when no
    directory is available or when the checkpoint file does not exist.

    Args:
      directory: Directory holding the checkpoint; self._output_dir is used
        when this is not given.
      filename: Name of the checkpoint file; 'model.pkl.gz' when not given.

    Raises:
      ValueError: If the checkpoint carries a single 'slots' entry but this
        loop does not have exactly one task.
    """
    ckpt_dir = directory or self._output_dir
    if ckpt_dir is None:
      _log('Not loading as both directory and output_dir are None.',
           stdout=False)
      return

    path = os.path.join(ckpt_dir, filename or 'model.pkl.gz')
    if not tf.io.gfile.exists(path):
      _log(f'Not loading as checkpoint file does not exist: {path}.',
           stdout=False)
      return

    ckpt = unpickle_from_file(path, gzip=True)

    if self._use_memory_efficient_trainer:
      # Large models keep their weights in separate shard files; in that case
      # the 'flat_weights' entry holds the number of shards, not the weights.
      flat_weights = []
      for shard_index in range(ckpt['flat_weights']):
        shard = unpickle_from_file(path + '.shard%d' % shard_index, gzip=True)
        # Bit-casting may put the arrays on the accelerator; move them back.
        restored = self._from_bits(shard)
        flat_weights.extend([tl.on_cpu(x) for x in restored])
    else:
      flat_weights = self._from_bits(ckpt['flat_weights'])

    self._step = ckpt['step']

    if 'slots' in ckpt:
      # Single-task layout: one 'slots' entry rather than a per-task list.
      if len(self._tasks) != 1:
        raise ValueError(
            'Can\'t load a single-task checkpoint into a multitask Loop.'
        )
      slots_per_task = [ckpt['slots']]
    else:
      slots_per_task = ckpt['slots_per_task']

    # Slots live on the trainers in memory-efficient mode, and on each task's
    # optimizer otherwise.
    if self._use_memory_efficient_trainer:
      for trainer, task_slots in zip(self._trainer_per_task, slots_per_task):
        trainer.slots = task_slots
    else:
      for task, task_slots in zip(self._tasks, slots_per_task):
        task.optimizer.slots = task_slots

    # This is self._model.init_from_file but optimized to not re-read.
    signature = self._model.weights_and_state_signature(
        ckpt['input_signature'])
    model_weights, model_state = tl.unflatten_weights_and_state(
        flat_weights, ckpt['flat_state'], signature)
    self._model.state = model_state
    self._model.weights = model_weights
    self._eval_model.weights = self._model.weights

    # The eval model's state is not always the same as the train state. Old
    # checkpoints did not save it, so fall back to the train state for them;
    # remove this fallback once those checkpoints are ported.
    flat_eval_state = (
        ckpt['flat_eval_state'] if 'flat_eval_state' in ckpt
        else ckpt['flat_state'])
    _, eval_state = tl.unflatten_weights_and_state(
        flat_weights, flat_eval_state, signature)
    self._eval_model.state = eval_state
Esempio n. 2
0
def trainer_state_from_dict(trainer_state_dict, model):
  """Builds a `TrainerState` out of a trainer state dictionary.

  Args:
    trainer_state_dict: Dictionary that is read for the keys 'step',
      'history', 'input_signature', 'flat_weights', 'flat_state', 'slots'
      and 'opt_params'.
    model: Object whose `weights_and_state_signature` method is used to
      unflatten the stored weights and state.

  Returns:
    A `TrainerState` with the step, optimizer state, history and model state.
  """
  # TODO(afrozm): This becomes simpler if OptState is flattened into
  # TrainerState.
  step, history, input_signature = (
      trainer_state_dict[key] for key in ('step', 'history', 'input_signature'))

  signature = model.weights_and_state_signature(input_signature)
  flat_weights = trainer_state_dict['flat_weights']
  flat_state = trainer_state_dict['flat_state']
  weights, model_state = tl.unflatten_weights_and_state(
      flat_weights, flat_state, signature)

  optimizer_fields = OptState(
      weights=weights,
      slots=trainer_state_dict['slots'],
      opt_params=trainer_state_dict['opt_params'])
  # The OptState is rebuilt from its own fields before being handed over.
  return TrainerState(
      step=step,
      opt_state=OptState(*optimizer_fields),
      history=history,
      model_state=model_state)
Esempio n. 3
0
 def test_flat_weights_and_state(self):
   """Flattens one model's weights/state and unflattens them into another."""
   input_sig = shapes.signature(np.zeros((2, 3)))

   source = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup()))
   source.init(input_sig)
   flat_weights, flat_state = tl.flatten_weights_and_state(
       source.weights, source.state)
   # The two dense layers each contribute a (w, b) pair, so the flattened
   # weights hold 4 trainable entries.
   self.assertLen(flat_weights, 4)
   self.assertEmpty(flat_state)

   # A second model, structured differently, receives the same flat values
   # through its own weights-and-state signature.
   target = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7))
   target_sig = target.weights_and_state_signature(input_sig)
   restored_weights, restored_state = tl.unflatten_weights_and_state(
       flat_weights, flat_state, target_sig)
   target.weights = restored_weights
   target.state = restored_state

   self.assertLen(target.weights, 3)
   self.assertEqual(source.weights[1], target.weights[0])
   self.assertEqual(source.weights[2][0], target.weights[2])