def load_checkpoint(self, directory=None, filename=None):
  """Restores step, weights, state and optimizer slots from a checkpoint file.

  Logs a message and returns without touching anything when no directory is
  available or when the checkpoint file is not found.

  Args:
    directory: Directory with the checkpoint (self._output_dir is used when
      this is None or otherwise falsy).
    filename: Checkpoint file name ('model.pkl.gz' when None or falsy).

  Raises:
    ValueError: If the checkpoint has a single-task 'slots' entry but this
      object does not have exactly one task.
  """
  if not directory:
    directory = self._output_dir
  if directory is None:
    _log('Not loading as both directory and output_dir are None.',
         stdout=False)
    return

  if not filename:
    filename = 'model.pkl.gz'
  path = os.path.join(directory, filename)
  if not tf.io.gfile.exists(path):
    _log(f'Not loading as checkpoint file does not exist: {path}.',
         stdout=False)
    return

  checkpoint = unpickle_from_file(path, gzip=True)

  if self._use_memory_efficient_trainer:
    # Large models keep their weights in separate shard files; in that case
    # the 'flat_weights' entry of the main file is the number of shards.
    shard_count = checkpoint['flat_weights']
    flat_weights = []
    for shard_index in range(shard_count):
      shard = unpickle_from_file(path + '.shard%d' % shard_index, gzip=True)
      shard = self._from_bits(shard)
      # Bit-casting may have moved the arrays to the accelerator; bring them
      # back to CPU before collecting them.
      flat_weights += [tl.on_cpu(w) for w in shard]
  else:
    flat_weights = self._from_bits(checkpoint['flat_weights'])

  self._step = checkpoint['step']

  # A checkpoint with a bare 'slots' entry is a single-task one; present it
  # as a one-element per-task list so the code below handles both formats.
  if 'slots' in checkpoint:
    if len(self._tasks) != 1:
      raise ValueError(
          'Can\'t load a single-task checkpoint into a multitask Loop.'
      )
    checkpoint['slots_per_task'] = [checkpoint['slots']]

  # Slots live on the trainers in memory-efficient mode and on each task's
  # optimizer otherwise.
  if self._use_memory_efficient_trainer:
    for trainer, task_slots in zip(self._trainer_per_task,
                                   checkpoint['slots_per_task']):
      trainer.slots = task_slots
  else:
    for task, task_slots in zip(self._tasks, checkpoint['slots_per_task']):
      task.optimizer.slots = task_slots

  # Same effect as self._model.init_from_file, but reuses the data that was
  # already read instead of reading the file again.
  signature = self._model.weights_and_state_signature(
      checkpoint['input_signature'])
  model_weights, model_state = tl.unflatten_weights_and_state(
      flat_weights, checkpoint['flat_state'], signature)
  self._model.state = model_state
  self._model.weights = model_weights
  self._eval_model.weights = self._model.weights

  # The eval model's state is not always the same as the train state. Old
  # checkpoints did not save it separately, so fall back to the train state
  # for those; remove the fallback once they are ported.
  flat_eval_state = (
      checkpoint['flat_eval_state'] if 'flat_eval_state' in checkpoint
      else checkpoint['flat_state'])
  _, eval_state = tl.unflatten_weights_and_state(
      flat_weights, flat_eval_state, signature)
  self._eval_model.state = eval_state
def trainer_state_from_dict(trainer_state_dict, model):
  """Builds a `TrainerState` from its dictionary representation.

  Args:
    trainer_state_dict: Mapping with the keys 'step', 'history',
      'input_signature', 'flat_weights', 'flat_state', 'slots' and
      'opt_params'.
    model: Object whose `weights_and_state_signature(input_signature)` is used
      to unflatten the stored weights and state.

  Returns:
    A `TrainerState` with the step, optimizer state, history and model state
    taken from the dictionary.
  """
  # TODO(afrozm): This becomes simpler if OptState is flattened into
  # TrainerState.
  saved = trainer_state_dict
  step, history = saved['step'], saved['history']

  signature = model.weights_and_state_signature(saved['input_signature'])
  weights, model_state = tl.unflatten_weights_and_state(
      saved['flat_weights'], saved['flat_state'], signature)

  optimizer_fields = OptState(
      weights=weights,
      slots=saved['slots'],
      opt_params=saved['opt_params'])
  # NOTE(review): the OptState is rebuilt from its own unpacked fields, as the
  # code has always done; presumably redundant if OptState is a namedtuple --
  # TODO confirm before simplifying.
  return TrainerState(
      step=step,
      opt_state=OptState(*optimizer_fields),
      history=history,
      model_state=model_state)
def test_flat_weights_and_state(self):
  """Flattens one model's weights/state and unflattens them into another."""
  source_model = tl.Serial(
      tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup()))
  input_signature = shapes.signature(np.zeros((2, 3)))
  source_model.init(input_signature)
  flat_weights, flat_state = tl.flatten_weights_and_state(
      source_model.weights, source_model.state)

  # Each of the two dense layers contributes a (w, b) pair, so the flat list
  # holds 4 trainable weights; there is no state to flatten.
  self.assertLen(flat_weights, 4)
  self.assertEmpty(flat_state)

  # The target model holds the same two dense layers, laid out differently.
  target_model = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7))
  target_signature = target_model.weights_and_state_signature(
      input_signature)
  restored_weights, restored_state = tl.unflatten_weights_and_state(
      flat_weights, flat_state, target_signature)
  target_model.weights = restored_weights
  target_model.state = restored_state

  self.assertLen(target_model.weights, 3)
  self.assertEqual(source_model.weights[1], target_model.weights[0])
  self.assertEqual(source_model.weights[2][0], target_model.weights[2])