def save_checkpoint(self): """Saves checkpoint to disk for the current training step.""" if not self.is_chief: _log('Did not save checkpoint as we are not chief.') return if self._output_dir is None: _log('Did not save checkpoint as output_dir is None') return weights = self._model.weights state = self._model.state if self._use_memory_efficient_trainer: slots_per_task = [trainer.slots for trainer in self._trainer_per_task] else: slots_per_task = tuple(task.optimizer.slots for task in self._tasks) # We only need the input signature for the body, not for the loss layers. # That part is the same across tasks - take it from the first one. input_signature = self._batch_signature[:self._model.n_in] flat_weights, flat_state = tl.flatten_weights_and_state(weights, state) _, flat_eval_state = tl.flatten_weights_and_state( weights, self._eval_model.state) d = { 'step': self.step, 'flat_weights': flat_weights, 'flat_state': flat_state, 'flat_eval_state': flat_eval_state, 'slots_per_task': slots_per_task, 'input_signature': input_signature, 'version_timestamp': 'Sep-17-2020' # To update in the future if needed. } ckpt_file = os.path.join(self._output_dir, 'model.pkl.gz') pickle_to_file(d, ckpt_file, gzip=True)
def save_checkpoint(self, permanent=False):
  """Saves checkpoint to disk for the current training step."""
  if not self.is_chief:
    _log('Did not save checkpoint as we are not chief.')
    return
  if self._output_dir is None:
    _log('Did not save checkpoint as output_dir is None')
    return
  if permanent:
    filename = 'model_{}.pkl.gz'.format(self.step)
  else:
    filename = 'model.pkl.gz'
  ckpt_file = os.path.join(self._output_dir, filename)
  _log('Saving checkpoint to %s.' % ckpt_file, stdout=False)
  weights = self._model.weights
  state = self._model.state
  if self._use_memory_efficient_trainer:
    slots_per_task = [trainer.slots for trainer in self._trainer_per_task]
  else:
    slots_per_task = tuple(task.optimizer.slots for task in self._tasks)
  # We only need the input signature for the body, not for the loss layers.
  # That part is the same across tasks - take it from the first one.
  input_signature = self._batch_signature[:self._model.n_in]
  flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
  _, flat_eval_state = tl.flatten_weights_and_state(
      weights, self._eval_model.state)
  if self._use_memory_efficient_trainer:
    sharded_weights_len = self._save_weights_sharded(flat_weights, ckpt_file)
    # In the main dict we just save the number of shards in place of weights.
    weights_in_dict = sharded_weights_len
  else:
    weights_in_dict = self._to_bits(flat_weights)
  d = {
      'step': self.step,
      'flat_weights': weights_in_dict,
      'flat_state': flat_state,
      'flat_eval_state': flat_eval_state,
      'slots_per_task': slots_per_task,
      'input_signature': input_signature,
      'version_timestamp': 'Oct-28-2020'  # To update in the future if needed.
  }
  pickle_to_file(d, ckpt_file, gzip=True)
  # Move sharded files to non-tmp files after all is saved.
  if self._use_memory_efficient_trainer:
    for i in range(weights_in_dict):
      fname = ckpt_file + '.shard%d' % i
      tf.io.gfile.rename(fname + '.tmp', fname, overwrite=True)
  _log('Checkpoint saved in %s.' % ckpt_file, stdout=False)
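# A minimal sketch of reading back the dict that the `save_checkpoint` methods
# above write, assuming the checkpoint is a gzip-compressed pickle on a local
# filesystem (the trainer has its own loading path; this helper is hypothetical
# and only illustrates the on-disk layout for inspection).
import gzip
import pickle


def inspect_checkpoint(ckpt_file):
  """Loads a `model.pkl.gz` checkpoint and prints its top-level contents."""
  with gzip.open(ckpt_file, 'rb') as f:
    d = pickle.load(f)
  print('step:', d['step'])
  print('version:', d['version_timestamp'])
  print('keys:', sorted(d.keys()))
  return d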
def make_trainer_state_dict(step, opt_state, history, model_state,
                            input_signature):
  """Creates a trainer state dictionary to save to disk.

  Args:
    step: int, a step number
    opt_state: OptState namedtuple
    history: `trax.history.History`, the history object.
    model_state: A nested structure of the model state.
    input_signature: signature of model inputs.

  Returns:
    A dictionary with the fields of TrainerState and OptState flattened.
  """
  flat_weights, flat_state = tl.flatten_weights_and_state(
      opt_state.weights, model_state)
  return {
      'step': step,
      'flat_weights': flat_weights,
      'slots': opt_state.slots,
      'opt_params': opt_state.opt_params,
      'history': history,
      'flat_state': flat_state,
      'input_signature': input_signature,
      'version_timestamp': 'Jun-18-2020'  # To update in the future if needed.
  }
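# A sketch of the inverse direction: restoring a model from a dict produced by
# `make_trainer_state_dict` above. It assumes the same `tl` alias as the
# surrounding snippets and uses `tl.unflatten_weights_and_state`, keyed by the
# model's weights-and-state signature (the same call appears in the test
# further below). The helper name itself is hypothetical, not part of trax.
def restore_weights_and_state(model, state_dict, input_signature):
  """Unflattens saved weights/state and assigns them back onto `model`."""
  sig = model.weights_and_state_signature(input_signature)
  weights, state = tl.unflatten_weights_and_state(
      state_dict['flat_weights'], state_dict['flat_state'], sig)
  model.weights = weights
  model.state = state
  return model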
def save_checkpoint(self, weights=None, state=None, slots=None):
  """Saves checkpoint to disk for the current training step.

  Args:
    weights: Weights from model being trained.
    state: State (non-weight parameters) from model being trained.
    slots: Updatable weights for the optimizer in this training loop.
  """
  if not self.is_chief:
    return
  if self._output_dir is None:
    _log('Did not save checkpoint as output_dir is None', stdout=False)
    return
  weights = self._model_in_training.weights if weights is None else weights
  state = self._model_in_training.state if state is None else state
  slots = self._task.optimizer.slots if slots is None else slots
  flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
  d = {
      'step': self.step,
      'flat_weights': flat_weights,
      'flat_state': flat_state,
      'slots': slots,
      'input_signature': self._batch_signature,
      'version_timestamp': 'Jun-29-2020'  # To update in the future if needed.
  }
  ckpt_file = os.path.join(self._output_dir, 'model.pkl.gz')
  pickle_to_file(d, ckpt_file, gzip=True)
def test_flat_weights_and_state(self):
  model = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup()))
  sample_input_signature = shapes.signature(np.zeros((2, 3)))
  model.init(sample_input_signature)
  flat_weights, flat_state = tl.flatten_weights_and_state(
      model.weights, model.state)
  # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers.
  # So after making them flat, there are 4 trainable weights.
  self.assertLen(flat_weights, 4)
  self.assertEmpty(flat_state)
  model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7))
  sig = model2.weights_and_state_signature(sample_input_signature)
  weights2, state2 = tl.unflatten_weights_and_state(
      flat_weights, flat_state, sig)
  model2.weights = weights2
  model2.state = state2
  self.assertLen(model2.weights, 3)
  self.assertEqual(model.weights[1], model2.weights[0])
  self.assertEqual(model.weights[2][0], model2.weights[2])