Example #1
 def on_epoch_end(self, data: Data) -> None:
     if self.system.epoch_idx % self.frequency == 0:
         # Save all models and optimizer state
         for model in self.system.network.models:
             save_model(model, save_dir=self.directory, save_optimizer=True)
         # Save system state
         self.system.save_state(
             json_path=os.path.join(self.directory, self.system_file))
         print("FastEstimator-RestoreWizard: Saved milestones to {}".format(
             self.directory))
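These on_epoch_end milestones only pay off if the trace is actually attached to the training loop. A minimal wiring sketch, assuming a FastEstimator setup where pipeline and network are already built elsewhere and the built-in fastestimator.trace.io.RestoreWizard is used (the directory path is illustrative):

import fastestimator as fe
from fastestimator.trace.io import RestoreWizard

# pipeline and network are assumed to be constructed elsewhere (fe.Pipeline / fe.Network).
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=10,
                         traces=[RestoreWizard(directory="/tmp/restore")])
estimator.fit()  # re-running this after an interruption resumes from the last saved milestone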
Example #2
    def save_state(self, save_dir: str) -> None:
        """Load training state.

        Args:
            save_dir: The directory into which to save the state
        """
        os.makedirs(save_dir, exist_ok=True)
        # Start with the high-level info. We could use pickle for this but having it human readable is nice.
        state = {
            key: value
            for key, value in self.__dict__.items() if is_restorable(value)[0]
        }
        with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
            json.dump(state, fp, indent=4)
        # Save all of the models / optimizer states
        for model in self.network.models:
            save_model(model, save_dir=save_dir, save_optimizer=True)
        # Save everything else
        objects = {
            'summary': self.summary,
            'custom_graphs': self.custom_graphs,
            'traces': [
                trace.__getstate__() if hasattr(trace, '__getstate__') else {}
                for trace in self.traces
            ],
            'tops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.ops
            ],
            'pops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.postprocessing
            ],
            'nops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.pipeline.ops
            ],
            'ds': {
                mode: {
                    key: value.__getstate__()
                    for key, value in ds.items()
                    if hasattr(value, '__getstate__')
                }
                for mode, ds in self.pipeline.data.items()
            }
        }
        with open(os.path.join(save_dir, 'objects.pkl'), 'wb') as file:
            # We need to use a custom pickler here to handle MirroredStrategy, which will show up inside of tf
            # MirroredVariables in multi-gpu systems.
            p = pickle.Pickler(file)
            p.dispatch_table = copyreg.dispatch_table.copy()
            p.dispatch_table[MirroredStrategy] = pickle_mirroredstrategy
            p.dump(objects)
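The body of pickle_mirroredstrategy is not shown in this snippet; it is a reducer registered in the pickler's copyreg dispatch table. As a rough illustration of that mechanism (stand-in types only, not FastEstimator code), a dispatch-table entry maps a type to a function returning a (constructor, args) pair, which lets otherwise unpicklable objects be serialized:

import copyreg
import io
import pickle
import threading

class StrategyLike:
    """Stand-in for an object (e.g. a tf.distribute strategy) that pickle cannot handle directly."""
    def __init__(self):
        self.lock = threading.Lock()  # thread locks make an instance unpicklable by default

def reduce_strategy(obj):
    # A dispatch-table reducer returns (callable, args); here we simply rebuild a fresh instance.
    return (StrategyLike, ())

buffer = io.BytesIO()
p = pickle.Pickler(buffer)
p.dispatch_table = copyreg.dispatch_table.copy()
p.dispatch_table[StrategyLike] = reduce_strategy
p.dump({'strategy': StrategyLike()})  # would raise TypeError without the custom reducer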
Example #3
    def save_state(self, save_dir: str) -> None:
        """Load training state.

        Args:
            save_dir: The directory into which to save the state
        """
        os.makedirs(save_dir, exist_ok=True)
        # Start with the high-level info. We could use pickle for this but having it human readable is nice.
        state = {
            key: value
            for key, value in self.__dict__.items() if is_restorable(value)[0]
        }
        with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
            json.dump(state, fp, indent=4)
        # Save all of the models / optimizer states
        for model in self.network.models:
            save_model(model, save_dir=save_dir, save_optimizer=True)
        # Save the Summary object
        with open(os.path.join(save_dir, 'summary.pkl'), 'wb') as file:
            pickle.dump(self.summary, file)
        # Save the Traces
        with open(os.path.join(save_dir, 'traces.pkl'), 'wb') as file:
            pickle.dump([
                trace.__getstate__() if hasattr(trace, '__getstate__') else {}
                for trace in self.traces
            ], file)
        # Save the TensorOps
        with open(os.path.join(save_dir, 'tops.pkl'), 'wb') as file:
            pickle.dump([
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.ops
            ], file)
        # Save the NumpyOps
        with open(os.path.join(save_dir, 'nops.pkl'), 'wb') as file:
            pickle.dump([
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.pipeline.ops
            ], file)
        # Save the Datasets
        with open(os.path.join(save_dir, 'ds.pkl'), 'wb') as file:
            pickle.dump(
                {
                    key: value.__getstate__()
                    for key, value in self.pipeline.data.items()
                    if hasattr(value, '__getstate__')
                }, file)
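The snippet above only covers saving; a hypothetical loading counterpart would read each pickle back and push the state into freshly re-instantiated objects via __setstate__, mirroring the __getstate__ calls used here. A sketch, assuming the traces are rebuilt in the same order before restoring:

    def load_state(self, load_dir: str) -> None:
        """Hypothetical counterpart to save_state() above (not part of the original snippet)."""
        with open(os.path.join(load_dir, 'traces.pkl'), 'rb') as file:
            trace_states = pickle.load(file)
        for trace, state in zip(self.traces, trace_states):
            if state and hasattr(trace, '__setstate__'):
                trace.__setstate__(state)  # mirror of the __getstate__ call used when saving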
Example #4
 def on_epoch_end(self, data: Data) -> None:
     # No model will be saved when save_dir is None, which makes smoke test easier.
     if self.save_dir and self.system.epoch_idx % self.frequency == 0:
         model_name = "{}_epoch_{}".format(self.model.model_name,
                                           self.system.epoch_idx)
         model_path = save_model(self.model, self.save_dir, model_name)
         print("FastEstimator-ModelSaver: Saved model to {}".format(
             model_path))
Example #5
 def test_torch_model_on_begin(self):
     restore_wizard = RestoreWizard(directory=self.system_json_path)
     restore_wizard.system = sample_system_object_torch()
     # save state
     for model in restore_wizard.system.network.models:
         save_model(model,
                    save_dir=restore_wizard.directory,
                    save_optimizer=True)
     restore_wizard.system.save_state(json_path=os.path.join(
         restore_wizard.directory, restore_wizard.system_file))
     restore_wizard.on_begin(data=self.data)
     with self.subTest('Check the restore files directory'):
         self.assertEqual(restore_wizard.directory, self.system_json_path)
     with self.subTest('check data dictionary'):
         self.assertEqual(self.data['epoch'], 0)
     if os.path.exists(self.system_json_path):
         shutil.rmtree(self.system_json_path)
Example #6
 def on_epoch_end(self, data: Data) -> None:
     # No model will be saved when save_dir is None, which makes smoke test easier.
     if self.save_dir and self.system.epoch_idx % self.frequency == 0:
         model_name = "{}_epoch_{}".format(self.model.model_name, self.system.epoch_idx)
         model_path = save_model(self.model, self.save_dir, model_name)
         print("FastEstimator-ModelSaver: Saved model to {}".format(model_path))
         rm_path = self.file_queue[self.file_queue.maxlen - 1] if self.file_queue.maxlen else None
         if rm_path:
             os.remove(rm_path)
             print("FastEstimator-ModelSaver: Removed model {} due to file number exceeding max_to_keep".format(
                 rm_path))
         self.file_queue.appendleft(model_path)
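The rotation above depends on how file_queue is constructed. A small standalone sketch, assuming the queue is a fixed-length collections.deque pre-filled with None placeholders (so indexing maxlen - 1 is always safe); appendleft() then evicts the oldest checkpoint path from the right end each time a new one is saved:

from collections import deque

max_to_keep = 2
file_queue = deque([None] * max_to_keep, maxlen=max_to_keep)  # assumed initialization, not shown in the snippet

for epoch in range(1, 5):
    model_path = "model_epoch_{}.pt".format(epoch)
    rm_path = file_queue[file_queue.maxlen - 1] if file_queue.maxlen else None
    if rm_path:
        print("would remove", rm_path)  # stands in for os.remove(rm_path)
    file_queue.appendleft(model_path)

print(list(file_queue))  # ['model_epoch_4.pt', 'model_epoch_3.pt']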
Example #7
 def on_epoch_end(self, data: Data) -> None:
     if self.monitor_op(data[self.metric], self.best):
         self.best = data[self.metric]
         self.since_best = 0
         if self.save_dir:
             self.model_path = save_model(self.model, self.save_dir,
                                          self.model_name)
             print("FastEstimator-BestModelSaver: Saved model to {}".format(
                 self.model_path))
     else:
         self.since_best += 1
     data.write_with_log(self.outputs[0], self.since_best)
     data.write_with_log(self.outputs[1], self.best)
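For context, a minimal standalone sketch of the bookkeeping this trace performs, assuming a metric that should be minimized (for a metric to maximize, monitor_op would be np.greater and best would start at -inf); the initialization shown here is an assumption, not taken from the snippet:

import numpy as np

monitor_op, best, since_best = np.less, np.inf, 0

for epoch_loss in [0.9, 0.7, 0.8, 0.65]:
    if monitor_op(epoch_loss, best):
        best, since_best = epoch_loss, 0   # new best -> this is where save_model() would be called
    else:
        since_best += 1                    # epochs since the last improvement (useful for early stopping)

print(best, since_best)  # 0.65 0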