Ejemplo n.º 1
0
class Function(Module):
    """Turn a function into a Module by keeping the vars it uses."""

    def __init__(self, f: Callable, vc: VarCollection):
        """Function constructor.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.
        """
        super().__init__()
        if hasattr(f, '__name__'):
            # Prefix every variable name with "{<function name>}." so the wrapped
            # function each variable belongs to is visible in vars().
            self.vc = VarCollection(
                (f'{{{f.__name__}}}.{k}', v) for k, v in vc.items())
        else:
            # Objects without a __name__ keep the original variable names.
            self.vc = VarCollection(vc)
        self.__wrapped__ = f

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the given arguments and return its result."""
        return self.__wrapped__(*args, **kwargs)

    def vars(self, scope: str = '') -> VarCollection:
        """Return the VarCollection of the variables used by the function.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A new VarCollection built from the stored one, with ``scope`` prepended
            to every name when ``scope`` is non-empty.
        """
        if scope:
            return VarCollection((scope + k, v) for k, v in self.vc.items())
        return VarCollection(self.vc)

    @staticmethod
    def with_vars(vc: VarCollection):
        """Method to use as decorator in function definitions.

        Returns a decorator that wraps the decorated function in a ``Function``
        holding ``vc``, e.g. ``@Function.with_vars(vc)``.
        """
        def from_function(f: Callable):
            return Function(f, vc)

        return from_function
Ejemplo n.º 2
0
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    Each distinct variable is written once, under the first name it has in ``vc``.
    The output is an ``.npz`` archive: entries ``'0'``, ``'1'``, ... hold the values
    and ``'names'`` holds the matching variable names.

    When ``file`` is a filename, data is written to ``file + '.tmp'`` first and then
    moved over ``file``. An interrupted save therefore never leaves a truncated file
    under the final name.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        filename, file = file, open(
            file + '.tmp', 'wb'
        )  # Save to a temporary in case the job is killed while saving.
    try:
        data, names, seen = {}, [], set()
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref  # Save the referenced variable, not the reference wrapper.
            # Deduplicate by identity: one variable can be reachable under several names.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
        np.savez(file, names=np.array(names), **data)
    finally:
        # Always release the handle we opened, even if serialization fails.
        if do_close:
            file.close()
    if do_close:
        # os.replace atomically overwrites an existing destination on every platform.
        # os.rename raises on Windows when ``filename`` already exists.
        os.replace(filename + '.tmp', filename)
Ejemplo n.º 3
0
def load_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
    """
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        # Map each saved name to the archive key ('0', '1', ...) holding its value.
        name_index = {k: str(i) for i, k in enumerate(data['names'])}
        # Group all names under which each distinct variable appears. Variables are
        # keyed by identity because one variable can be reachable under several names.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        for var_id, names in var_names.items():
            # Restore the variable from the first of its names present in the file.
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    var_values[var_id].assign(jn.array(data[index]))
                    break
            else:
                raise ValueError(f'Missing value for variables {names}')
    finally:
        # Close the handle we opened even when a variable is missing.
        if do_close:
            file.close()
Ejemplo n.º 4
0
def load_var_collection(file: Union[str, IO[BinaryIO]],
                        vc: VarCollection,
                        renamer: Optional[Renamer] = None):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.
        renamer: optional renamer to pre-process variables names from the file being read.
            Defaults to the identity function.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
        AssertionError: if ``assign`` raises AssertionError for a loaded value (presumably
            a shape/type mismatch; confirm). The error is re-raised with the variable name.
    """
    renamer = renamer or (lambda x: x)
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        # Map (renamed) on-disk names to the archive keys holding their values.
        name_index = {renamer(k): str(i) for i, k in enumerate(data['names'])}
        # Group every name under which each distinct variable is reachable. Variables
        # are keyed by identity because one variable can appear under several names.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        misses = []
        used_vars = set()
        for var_id, names in var_names.items():
            v = var_values[var_id]
            # Restore the variable from the first of its names that exists on disk.
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    used_vars.add(name)
                    try:
                        v.assign(jn.array(data[index]))
                    except AssertionError as e:
                        raise AssertionError(
                            f'Error when restoring variable {name}: {str(e)}'
                        ) from None
                    break
            else:
                misses += names
        if misses:
            not_used = set(name_index.keys()) - used_vars
            raise ValueError(
                f'Missing value for variables currently in the model: {misses}. '
                f'The following variables on disk were not used, '
                f'maybe the missing variable was renamed from one of these: {not_used}.'
            )
    finally:
        # Close the handle we opened on every exit path, including the raises above.
        if do_close:
            file.close()
Ejemplo n.º 5
0
    def __init__(self, f: Callable, vc: VarCollection):
        """Build a module that wraps ``f`` and records the variables it uses.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.
        """
        if not hasattr(f, '__name__'):
            # Nameless callables keep the variable names unchanged.
            self.vc = VarCollection(vc)
        else:
            # Named callables get their names prefixed with "{<name>}".
            tag = f'{{{f.__name__}}}'
            self.vc = VarCollection((f'{tag}{name}', var) for name, var in vc.items())
        self.__wrapped__ = f
Ejemplo n.º 6
0
class ModuleWrapper(Module):
    """Module whose sole purpose is to store a collectable VarCollection. This class is exclusively
    used internally by Objax, for example in Jit, Vectorize and Parallel."""

    def __init__(self, vc: VarCollection):
        """Store the variables of ``vc``, each name prefixed with ``(<class name>)``.

        Args:
            vc: the variables to expose through ``vars()``.
        """
        super().__init__()
        tag = f'({self.__class__.__name__})'
        renamed = ((f'{tag}{name}', var) for name, var in vc.items())
        self.vc = VarCollection(renamed)

    def vars(self, scope: str = '') -> VarCollection:
        """Collect all the variables (and their names) contained in the VarCollection.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        scoped = ((scope + name, var) for name, var in self.vc.items())
        return VarCollection(scoped)
Ejemplo n.º 7
0
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    Each distinct variable is written once, under the first name it has in ``vc``.
    The output is an ``.npz`` archive: entries ``'0'``, ``'1'``, ... hold the values
    and ``'names'`` holds the matching variable names. A warning is printed if any
    value is a ``ShardedDeviceArray``, i.e. replicated across devices.

    When ``file`` is a filename, data is written to ``file + '.tmp'`` first and then
    moved over ``file``. An interrupted save therefore never leaves a truncated file
    under the final name.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        filename, file = file, open(
            file + '.tmp', 'wb'
        )  # Save to a temporary in case the job is killed while saving.
    try:
        data, names, seen, replicated = {}, [], set(), []
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref  # Save the referenced variable, not the reference wrapper.
            # Deduplicate by identity: one variable can be reachable under several names.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
            if isinstance(v.value, ShardedDeviceArray):
                replicated.append(k)
        if replicated:
            print(
                'Warning: When saving VarCollection, some variables were replicated on multiple devices.'
            )
            print(
                '         While it is valid, in most use cases it is more disk efficient to save variables outside of '
            )
            print('         vars().replicate().')
        np.savez(file, names=np.array(names), **data)
    finally:
        # Always release the handle we opened, even if serialization fails.
        if do_close:
            file.close()
    if do_close:
        # os.replace atomically overwrites an existing destination on every platform.
        # os.rename raises on Windows when ``filename`` already exists.
        os.replace(filename + '.tmp', filename)
Ejemplo n.º 8
0
 def __init__(self, vc: VarCollection):
     """Store the variables of ``vc``, each name prefixed with ``(<class name>)``."""
     super().__init__()
     tag = f'({self.__class__.__name__})'
     self.vc = VarCollection((f'{tag}{name}', var) for name, var in vc.items())