Exemplo n.º 1
0
 def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
     """
     Gather the checkpoint savers owned by this network.

     When distributed training is disabled, a single GlobalVariableSaver
     constructed from the network's name is registered. When distributed
     training is enabled, the returned collection is empty.
     :param parent_path_suffix: path suffix of the parent of the network
         (e.g. could be name of level manager plus name of agent); not used
         by this implementation
     :return: checkpoint collection for the network
     """
     collection = SaverCollection()
     if self.distributed_training:
         # No saver is registered in the distributed case.
         return collection
     collection.add(GlobalVariableSaver(self.name))
     return collection
 def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
     """
     Build the checkpoint savers for this network.

     A ParameterDictSaver over the model's collected parameters is always
     included. An OnnxSaver is added as well when the task parameters request
     ONNX graph export. Saver names have the form
     "<parent_path_suffix>.<network name>", with every '/' in the network
     name replaced by '.'. The ONNX saver's name additionally ends in ".onnx".
     :param parent_path_suffix: path suffix of the parent of the network
         (e.g. could be name of level manager plus name of agent)
     :return: checkpoint collection for the network
     """
     # NOTE(review): '/' is presumably replaced so the network's scope
     # separators don't end up as path separators in checkpoint names — confirm.
     dotted_name = self.name.replace('/', '.')
     params_saver = ParameterDictSaver(
         name="{}.{}".format(parent_path_suffix, dotted_name),
         param_dict=self.model.collect_params())
     collection = SaverCollection(params_saver)

     if self.ap.task_parameters.export_onnx_graph:
         onnx_saver = OnnxSaver(
             name="{}.{}.onnx".format(parent_path_suffix, dotted_name),
             model=self.model,
             input_shapes=self._model_input_shapes())
         collection.add(onnx_saver)

     return collection
def test_checkpoint_collection():
    """
    Exercise SaverCollection's add() and update() behaviour.

    After adding a '123' saver twice and a '456' saver once, the collection
    is expected to yield exactly one saver per path. The '123' saver should
    have count 2, which the test's merge() produces by accumulating counts;
    the '456' saver should have count 1. Updating the collection with itself
    is expected to double every count.
    """
    class SaverTest(Saver):
        """Minimal Saver that records how many savers have been merged into it."""

        def __init__(self, path):
            self._path = path
            self._count = 1

        @property
        def path(self):
            return self._path

        def merge(self, other: 'Saver'):
            assert isinstance(other, SaverTest)
            assert self.path == other.path
            self._count += other._count

    # test add
    savers = SaverCollection(SaverTest('123'))
    savers.add(SaverTest('123'))
    savers.add(SaverTest('456'))

    def check_collection(mul):
        """Assert the collection holds exactly the expected paths with counts scaled by `mul`."""
        # Expected merge count per path, before scaling by `mul`.
        expected = {'123': 2, '456': 1}
        for c in savers:
            # Check membership explicitly so an unexpected or duplicated path
            # fails with a clear message instead of a ValueError from
            # list.remove() (which made the old "invalid path" branch dead).
            assert c.path in expected, "invalid or duplicate path: {}".format(c.path)
            assert c._count == expected.pop(c.path) * mul
        # Every expected path must have been yielded. Without this check, a
        # saver silently missing from the collection would go unnoticed.
        assert not expected, "missing savers for paths: {}".format(sorted(expected))

    check_collection(1)

    # test update
    savers.update(savers)
    check_collection(2)