def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collection of all checkpoints for the network (typically only one checkpoint)

    When ``self.distributed_training`` is false, the returned collection holds a
    single GlobalVariableSaver created from ``self.name``; otherwise the
    collection is returned empty. ``parent_path_suffix`` is not used by this
    implementation.

    :param parent_path_suffix: path suffix of the parent of the network
        (e.g. could be name of level manager plus name of agent)
    :return: checkpoint collection for the network
    """
    collection = SaverCollection()
    if self.distributed_training:
        # No saver is registered by this method under distributed training.
        return collection
    collection.add(GlobalVariableSaver(self.name))
    return collection
def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
    """
    Collection of all checkpoints for the network (typically only one checkpoint)

    The parameter checkpoint is a ParameterDictSaver named
    "<parent_path_suffix>.<network name>", where every '/' in the network name
    is replaced with '.', and it is built from ``self.model.collect_params()``.
    When ``self.ap.task_parameters.export_onnx_graph`` is truthy, an OnnxSaver
    named "<parent_path_suffix>.<network name>.onnx" is added as well.

    :param parent_path_suffix: path suffix of the parent of the network
        (e.g. could be name of level manager plus name of agent)
    :return: checkpoint collection for the network
    """
    # Replace '/' separators in the network name with '.' before building names.
    dotted_name = self.name.replace('/', '.')
    base_name = "{}.{}".format(parent_path_suffix, dotted_name)

    params_saver = ParameterDictSaver(name=base_name,
                                      param_dict=self.model.collect_params())
    collection = SaverCollection(params_saver)

    export_onnx = self.ap.task_parameters.export_onnx_graph
    if export_onnx:
        onnx_saver = OnnxSaver(name="{}.onnx".format(base_name),
                               model=self.model,
                               input_shapes=self._model_input_shapes())
        collection.add(onnx_saver)
    return collection
def test_checkpoint_collection():
    """Check that SaverCollection merges savers sharing a path on add() and update()."""

    class SaverTest(Saver):
        """Minimal Saver that counts how many savers have been merged into it."""

        def __init__(self, path):
            self._path = path
            self._count = 1

        @property
        def path(self):
            return self._path

        def merge(self, other: 'Saver'):
            # Only savers of this test type with the same path may be merged.
            assert isinstance(other, SaverTest)
            assert self.path == other.path
            self._count += other._count

    # test add: the two '123' savers should be merged into a single entry
    savers = SaverCollection(SaverTest('123'))
    savers.add(SaverTest('123'))
    savers.add(SaverTest('456'))

    # Merge count expected per path after the adds above (scaled by `mul`).
    expected_counts = {'123': 2, '456': 1}

    def check_collection(mul):
        seen = []
        for c in savers:
            assert c.path in expected_counts, "invalid path: {}".format(c.path)
            assert c._count == expected_counts[c.path] * mul
            seen.append(c.path)
        # Every expected path must appear exactly once. Without this check a
        # saver missing from the collection would go unnoticed.
        assert sorted(seen) == sorted(expected_counts)

    check_collection(1)

    # test update: after updating the collection with itself, counts double
    savers.update(savers)
    check_collection(2)