Example #1
 def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
     return {
         'repeat_list': [
             elem if is_restorable(elem)[0] else
             elem.__getstate__() if hasattr(elem, '__getstate__') else {}
             for elem in self.repeat_list
         ]
     }
Example #2
 def __getstate__(self) -> Dict[str, Dict[int, Dict[Any, Any]]]:
     return {
         'epoch_dict': {
             key: elem if is_restorable(elem)[0] else
             elem.__getstate__() if hasattr(elem, '__getstate__') else {}
             for key, elem in self.epoch_dict.items()
         }
     }
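Both __getstate__ overrides above lean on an is_restorable helper whose first return value flags whether an object can be kept as-is; anything else is reduced to its own __getstate__ (or an empty dict). The helper itself is not shown in these examples, so the following is only a rough sketch under the assumption that it returns a (bool, size)-style tuple and that plain scalars and containers of scalars count as restorable; the library's real check may differ.

    from typing import Any, Tuple

    def is_restorable(data: Any) -> Tuple[bool, int]:
        # Hypothetical: plain scalars can be written back directly.
        if isinstance(data, (int, float, bool, str, type(None))):
            return True, 1
        # Hypothetical: containers are restorable only if every entry is.
        if isinstance(data, (list, tuple, set)):
            results = [is_restorable(elem) for elem in data]
            return all(ok for ok, _ in results), sum(size for _, size in results)
        if isinstance(data, dict):
            results = [is_restorable(elem) for elem in data.values()]
            return all(ok for ok, _ in results), sum(size for _, size in results)
        # Anything else falls back to __getstate__ in the snippets above.
        return False, 0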
Example #3
    def save_state(self, save_dir: str) -> None:
        """Load training state.

        Args:
            save_dir: The directory into which to save the state
        """
        os.makedirs(save_dir, exist_ok=True)
        # Start with the high-level info. We could use pickle for this but having it human readable is nice.
        state = {
            key: value
            for key, value in self.__dict__.items() if is_restorable(value)[0]
        }
        with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
            json.dump(state, fp, indent=4)
        # Save all of the models / optimizer states
        for model in self.network.models:
            save_model(model, save_dir=save_dir, save_optimizer=True)
        # Save everything else
        objects = {
            'summary': self.summary,
            'custom_graphs': self.custom_graphs,
            'traces': [
                trace.__getstate__() if hasattr(trace, '__getstate__') else {}
                for trace in self.traces
            ],
            'tops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.ops
            ],
            'pops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.postprocessing
            ],
            'nops': [
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.pipeline.ops
            ],
            'ds': {
                mode: {
                    key: value.__getstate__()
                    for key, value in ds.items()
                    if hasattr(value, '__getstate__')
                }
                for mode, ds in self.pipeline.data.items()
            }
        }
        with open(os.path.join(save_dir, 'objects.pkl'), 'wb') as file:
            # We need to use a custom pickler here to handle MirroredStrategy, which will show up inside of tf
            # MirroredVariables in multi-gpu systems.
            p = pickle.Pickler(file)
            p.dispatch_table = copyreg.dispatch_table.copy()
            p.dispatch_table[MirroredStrategy] = pickle_mirroredstrategy
            p.dump(objects)
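The custom pickler at the end of this example registers pickle_mirroredstrategy for MirroredStrategy in a copy of copyreg's dispatch_table. That reducer is not shown here; below is only a sketch of what such a copyreg-style reduction function could look like, under the assumption that it is acceptable to rebuild a fresh strategy on unpickling rather than serialize the live one.

    import tensorflow as tf

    MirroredStrategy = tf.distribute.MirroredStrategy

    def pickle_mirroredstrategy(strategy: MirroredStrategy):
        # copyreg-style reducer: return (callable, args) so that unpickling
        # constructs a new MirroredStrategy instead of trying to serialize
        # the original object and its device handles.
        return MirroredStrategy, ()

It would then be wired up exactly as in the snippet above, via p.dispatch_table[MirroredStrategy] = pickle_mirroredstrategy.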
Example #4
    def save_state(self, save_dir: str) -> None:
        """Load training state.

        Args:
            save_dir: The directory into which to save the state
        """
        os.makedirs(save_dir, exist_ok=True)
        # Start with the high-level info. We could use pickle for this but having it human readable is nice.
        state = {
            key: value
            for key, value in self.__dict__.items() if is_restorable(value)[0]
        }
        with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
            json.dump(state, fp, indent=4)
        # Save all of the models / optimizer states
        for model in self.network.models:
            save_model(model, save_dir=save_dir, save_optimizer=True)
        # Save the Summary object
        with open(os.path.join(save_dir, 'summary.pkl'), 'wb') as file:
            pickle.dump(self.summary, file)
        # Save the Traces
        with open(os.path.join(save_dir, 'traces.pkl'), 'wb') as file:
            pickle.dump([
                trace.__getstate__() if hasattr(trace, '__getstate__') else {}
                for trace in self.traces
            ], file)
        # Save the TensorOps
        with open(os.path.join(save_dir, 'tops.pkl'), 'wb') as file:
            pickle.dump([
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.network.ops
            ], file)
        # Save the NumpyOps
        with open(os.path.join(save_dir, 'nops.pkl'), 'wb') as file:
            pickle.dump([
                op.__getstate__() if hasattr(op, '__getstate__') else {}
                for op in self.pipeline.ops
            ], file)
        # Save the Datasets
        with open(os.path.join(save_dir, 'ds.pkl'), 'wb') as file:
            pickle.dump(
                {
                    key: value.__getstate__()
                    for key, value in self.pipeline.data.items()
                    if hasattr(value, '__getstate__')
                }, file)
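This variant writes each group of objects to its own pickle file, so a matching restore step would read each file back and hand the saved dictionaries to __setstate__ on the live objects. A minimal sketch of that idea for the traces file follows; the name load_traces and the fallback to updating __dict__ are assumptions for illustration, not the library's actual restore logic.

    import os
    import pickle

    def load_traces(traces, save_dir: str) -> None:
        # Hypothetical counterpart to the per-file save above: read traces.pkl
        # and push each saved state back into the corresponding live trace.
        with open(os.path.join(save_dir, 'traces.pkl'), 'rb') as file:
            states = pickle.load(file)
        for trace, state in zip(traces, states):
            if hasattr(trace, '__setstate__'):
                trace.__setstate__(state)
            else:
                trace.__dict__.update(state)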