Ejemplo n.º 1
0
    def _get_loss(self, loss_name: str):
        if loss_name == 'logistic':
            x = TrainVar(jn.zeros(2))
            model_vars = VarCollection({'x': x})

            def loss():
                return jn.log(jn.exp(-jn.sum(x.value)) + 1)

            return model_vars, loss
        if loss_name == 'square':
            # loss = x*x + y*y.
            x = TrainVar(jn.ones(2))
            y = TrainVar(jn.ones(3))
            model_vars = VarCollection({'x': x, 'y': y})

            def loss():
                return jn.dot(x.value, x.value) + jn.dot(y.value, y.value)

            return model_vars, loss
        if loss_name == 'rastrigin':
            d = 2
            x = TrainVar(jn.ones(d))
            model_vars = VarCollection({'x': x})

            def loss():
                return 10 * d + jn.dot(x.value, x.value) - 10 * jn.sum(
                    jn.cos(2 * math.pi * x.value))

            return model_vars, loss
        raise ValueError
Ejemplo n.º 2
0
    def __init__(self, f: Callable, vc: VarCollection):
        """Wrap a callable together with the variables it uses.

        When ``f`` has a ``__name__``, every variable name is prefixed with
        ``{<name>}`` so that variables of different wrapped functions stay distinct.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.
        """
        if not hasattr(f, '__name__'):
            self.vc = VarCollection(vc)
        else:
            prefix = f'{{{f.__name__}}}'
            self.vc = VarCollection((f'{prefix}{key}', var) for key, var in vc.items())
        self.__wrapped__ = f
Ejemplo n.º 3
0
class AnalyzeUserVariablesNodeTransformer(ast.NodeTransformer):
    """AST pass that resolves loaded names/attributes and collects the Objax variables they reach.

    Every ``Name`` (and ``Attribute`` chained on a resolved node) in a load context that can be
    resolved through ``closure_vars`` or ``global_vars`` is annotated with its dotted ``name`` and
    its resolved ``value``. Objax Modules and BaseVars found this way are recorded in ``self.vc``,
    after which the node's ``value`` annotation is reset to None.
    """

    def __init__(self, closure_vars, global_vars):
        self.closure_vars = closure_vars
        self.global_vars = global_vars
        self.vc = VarCollection()

    def check_objax_var_module(self, node):
        """Record the value annotated on ``node`` if it is an Objax Module or variable.

        Raises:
            ValueError: if a different variable was already recorded under the same name.
        """
        if not hasanno(node, 'value'):
            return
        value = getanno(node, 'value')
        name = getanno(node, 'name')
        if value is None:
            return
        if isinstance(value, Module):
            # Pull in every variable of the module, scoped under the module's dotted name.
            self.vc.update(value.vars(scope=name + '.'))
            setanno(node, 'value', None)
        if isinstance(value, BaseVar):
            if name in self.vc and self.vc[name] is not value:
                # This generally should not happen and probably indication of a bug somewhere.
                raise ValueError(
                    f'Variable tracing failed because two variables were found with the same name {name}'
                )
            self.vc[name] = value
            setanno(node, 'value', None)

    def visit_Name(self, node):
        """Resolve a loaded name, preferring closure variables over globals."""
        node = self.generic_visit(node)
        if not isinstance(node.ctx, ast.Load):
            return node
        for lookup in (self.closure_vars, self.global_vars):
            if node.id in lookup:
                setanno(node, 'name', node.id)
                setanno(node, 'value', lookup[node.id])
                self.check_objax_var_module(node)
                break
        return node

    def visit_Attribute(self, node):
        """Resolve ``owner.attr`` when ``owner`` was itself resolved to a non-None value."""
        node = self.generic_visit(node)
        if not (isinstance(node.ctx, ast.Load) and hasanno(node.value, 'value')):
            return node
        owner = getanno(node.value, 'value')
        if owner is None or not hasattr(owner, node.attr):
            return node
        setanno(node, 'name', getanno(node.value, 'name') + '.' + node.attr)
        setanno(node, 'value', getattr(owner, node.attr))
        self.check_objax_var_module(node)
        return node
Ejemplo n.º 4
0
class Function(Module):
    """Module adapter around a plain callable that exposes the variables the callable uses."""

    def __init__(self, f: Callable, vc: VarCollection):
        """Wrap ``f`` with its variables.

        When ``f`` has a ``__name__``, each variable name is prefixed with ``{<name>}.``.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.
        """
        if not hasattr(f, '__name__'):
            self.vc = VarCollection(vc)
        else:
            prefix = f'{{{f.__name__}}}.'
            self.vc = VarCollection((f'{prefix}{key}', var) for key, var in vc.items())
        self.__wrapped__ = f

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the given arguments."""
        return self.__wrapped__(*args, **kwargs)

    def vars(self, scope: str = '') -> VarCollection:
        """Return a new VarCollection of the wrapped variables, names prefixed by ``scope``."""
        if not scope:
            return VarCollection(self.vc)
        return VarCollection((scope + key, var) for key, var in self.vc.items())

    @staticmethod
    def with_vars(vc: VarCollection):
        """Return a decorator that turns the decorated function into a Function using ``vc``."""
        def wrap(f: Callable):
            return Function(f, vc)

        return wrap
Ejemplo n.º 5
0
    def __init__(self, f: Callable,
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None):
        """Constructs an instance to compute the gradient of f w.r.t. variables.

        Args:
            f: the function for which to compute gradients.
            variables: the variables for which to compute gradients.
            input_argnums: input indexes, if any, on which to compute gradients.
        """
        variables = variables or VarCollection()
        super().__init__(variables)
        self.input_argnums = input_argnums or tuple()

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray],
                   list_args: List):
            # Leading entries are the differentiated positional inputs; the rest are the
            # values to substitute into the TrainVars of self.vc.
            inputs = inputs_and_train_tensors[:len(self.input_argnums)]
            train_tensors = inputs_and_train_tensors[len(self.input_argnums):]
            original_vc = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for i, arg in zip(self.input_argnums, inputs):
                    list_args[i] = arg
                outputs = f(*list_args)
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                # The first output is differentiated; all outputs and the updated state
                # tensors are returned as auxiliary data.
                return outputs[0], (outputs, variables.tensors(BaseState))
            finally:
                # Restore the original values even if f raises, so substituted (possibly
                # traced) tensors never remain in the variables.
                self.vc.assign(original_vc)

        self.f = jax.grad(f_func, has_aux=True)
Ejemplo n.º 6
0
def load_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.
    When ``file`` is a filename, the file opened here is closed on every exit path, including errors.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
    """
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        # Saved variable name -> key of its array in the archive.
        name_index = {k: str(i) for i, k in enumerate(data['names'])}
        # Group every name under which an underlying variable is reachable (a TrainRef is
        # resolved to its referenced variable). Keyed by identity so the variable type does
        # not need to be hashable or to have identity-based equality.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        for var_id, names in var_names.items():
            # Any one of the variable's names present in the file is enough to restore it.
            index = next((name_index[name] for name in names if name in name_index), None)
            if index is None:
                raise ValueError(f'Missing value for variables {names}')
            var_values[var_id].assign(jn.array(data[index]))
    finally:
        if do_close:
            file.close()
Ejemplo n.º 7
0
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When ``file`` is a filename, data is written to ``file + '.tmp'`` and only moved onto
    ``file`` after a successful write, so an interrupted save never leaves a truncated file.
    Each underlying variable is stored once, under the first name it appears with.

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        # Save to a temporary in case the job is killed while saving.
        filename, file = file, open(file + '.tmp', 'wb')
    try:
        data, names, seen = {}, [], set()
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            # Deduplicate by identity: the same variable may appear under several names.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
        np.savez(file, names=np.array(names), **data)
    finally:
        if do_close:
            file.close()
    if do_close:
        # os.replace (unlike os.rename) also overwrites an existing destination on Windows,
        # and is atomic when both paths are on the same filesystem.
        os.replace(filename + '.tmp', filename)
Ejemplo n.º 8
0
    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables held by this list, directly or through its submodules.

        Each entry is named ``<scope>(<ClassName>)[<position>]``. Submodules are queried
        with that name as their own scope.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        prefix = scope + f'({self.__class__.__name__})'
        collected = VarCollection()
        for position, item in enumerate(self):
            item_scope = f'{prefix}[{position}]'
            if isinstance(item, BaseVar):
                collected[item_scope] = item
            elif isinstance(item, Module):
                collected.update(item.vars(scope=item_scope))
        return collected
Ejemplo n.º 9
0
    def __init__(self, vc: VarCollection):
        """Set up SGD over the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize; only its TrainVar subset is kept.
        """
        trainable = vc.subset(TrainVar)
        self.train_vars = ModuleList(TrainRef(var) for var in trainable)
Ejemplo n.º 10
0
    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables held as attributes of this module and of its submodules.

        Important: variables and modules stored inside Python containers such as dict or
        list attributes are not collected; use ModuleList if you need that.

        Each variable attribute is named ``<scope>(<ClassName>).<attribute>``. Submodule
        attributes are queried with that same string as their scope.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        prefix = scope + f'({self.__class__.__name__}).'
        collected = VarCollection()
        for attr_name, attr in self.__dict__.items():
            if isinstance(attr, BaseVar):
                collected[prefix + attr_name] = attr
            elif isinstance(attr, Module):
                collected.update(attr.vars(scope=prefix + attr_name))
        return collected
Ejemplo n.º 11
0
    def vars(self, scope: str = '') -> VarCollection:
        """Return a new VarCollection with every entry of ``self.vc`` renamed to ``scope + name``.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        entries = self.vc.items()
        return VarCollection((scope + key, var) for key, var in entries)
Ejemplo n.º 12
0
def load_var_collection(file: Union[str, IO[BinaryIO]],
                        vc: VarCollection,
                        renamer: Optional[Renamer] = None):
    """Loads values of all variables in the given variables collection from file.

    Values loaded from file will replace old values in the variables collection.
    If variable exists in the file, but does not exist in the variables collection it will be ignored.
    If variable exists in the variables collection, but not found in the file then exception will be raised.
    When ``file`` is a filename, the file opened here is closed on every exit path, including errors.

    Args:
        file: filename or python file handle of the input file.
        vc: variables collection which will be loaded from file.
        renamer: optional renamer to pre-process variables names from the file being read.

    Raises:
        ValueError: if variable from variables collection is not found in the input file.
        AssertionError: re-raised with the variable name when assigning a loaded value fails.
    """
    renamer = renamer or (lambda x: x)
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    try:
        data = np.load(file, allow_pickle=False)
        # Renamed on-disk name -> key of the corresponding array in the archive.
        name_index = {renamer(k): str(i) for i, k in enumerate(data['names'])}
        # Group names by underlying variable (keyed by identity; TrainRefs resolved) so a
        # variable reachable under several names is restored only once.
        var_names = collections.defaultdict(list)
        var_values = {}
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            var_names[id(v)].append(k)
            var_values[id(v)] = v
        misses = []
        used_vars = set()
        for var_id, names in var_names.items():
            v = var_values[var_id]
            for name in names:
                index = name_index.get(name)
                if index is not None:
                    used_vars.add(name)
                    try:
                        v.assign(jn.array(data[index]))
                    except AssertionError as e:
                        raise AssertionError(
                            f'Error when restoring variable {name}: {str(e)}'
                        ) from None
                    break
            else:
                misses += names
        if misses:
            not_used = set(name_index.keys()) - used_vars
            raise ValueError(
                f'Missing value for variables currently in the model: {misses}. '
                f'The following variables on disk were not used, '
                f'maybe the missing variable was renamed from one of these: {not_used}.'
            )
    finally:
        if do_close:
            file.close()
Ejemplo n.º 13
0
    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Set up Adam over the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize.
            beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
            beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
            eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
        """
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        # Step counter (uint32 scalar); its reduce keeps only the first element.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda values: values[0])
        trainable = vc.subset(TrainVar)
        self.train_vars = ModuleList(TrainRef(var) for var in trainable)
        # First and second moment accumulators, zero-initialised like each trainable value.
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
        self.v = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
Ejemplo n.º 14
0
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 nesterov: bool = False):
        """Set up a momentum optimizer over the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: bool indicating whether to use the Nesterov method.
        """
        self.momentum, self.nesterov = momentum, nesterov
        trainable = vc.subset(TrainVar)
        self.train_vars = ModuleList(TrainRef(var) for var in trainable)
        # One zero-initialised accumulator per trainable variable, shaped like its value.
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
Ejemplo n.º 15
0
class ModuleWrapper(Module):
    """Module whose sole purpose is to store a collectable VarCollection. This class is exclusively
    used internally by Objax, for example in Jit, Vectorize and Parallel."""

    def __init__(self, vc: VarCollection):
        super().__init__()
        # Prefix each variable name with the concrete class name, e.g. "(Jit)".
        prefix = f'({self.__class__.__name__})'
        self.vc = VarCollection((f'{prefix}{key}', var) for key, var in vc.items())

    def vars(self, scope: str = '') -> VarCollection:
        """Return a new VarCollection with every stored entry renamed to ``scope + name``.

        Args:
            scope: string to prefix to the variable names.
        Returns:
            A VarCollection of all the variables.
        """
        entries = self.vc.items()
        return VarCollection((scope + key, var) for key, var in entries)
Ejemplo n.º 16
0
    def __init__(self,
                 derivative_fn: Callable,
                 f: Union[Module, Callable],
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None,
                 return_all_f_outputs: bool = False):
        """Constructs an instance to compute the derivatives of f w.r.t. variables.

        Args:
            derivative_fn: JAX transformation which computes derivative.
            f: the function for which to compute derivatives.
            variables: the variables for which to compute derivatives.
            input_argnums: input indexes, if any, on which to compute derivatives.
            return_all_f_outputs: if True also return original outputs of the function along with derivatives.
        """
        self.f = f
        self.vc = variables = VarCollection(variables or ())
        if not isinstance(f, Module):
            f = Function(f, self.vc)

        assert isinstance(input_argnums, tuple) or input_argnums is None, \
            f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
        self.input_argnums = input_argnums or tuple()
        self.return_all_f_outputs = return_all_f_outputs

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray], list_args: List,
                   kwargs: Dict):
            # Leading entries are the differentiated inputs; the remainder feed the TrainVars.
            num_inputs = len(self.input_argnums)
            inputs = inputs_and_train_tensors[:num_inputs]
            train_tensors = inputs_and_train_tensors[num_inputs:]
            saved_tensors = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for position, value in zip(self.input_argnums, inputs):
                    list_args[position] = value
                outputs = f(*list_args, **kwargs)
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                primary = outputs[0]
                final_tensors = variables.tensors()
                if self.return_all_f_outputs:
                    return primary, (outputs, final_tensors)
                return primary, final_tensors
            finally:
                # Put the original values back whatever happened inside f.
                self.vc.assign(saved_tensors)

        self._call = derivative_fn(f_func)
Ejemplo n.º 17
0
    def __init__(self,
                 f: Union[Module, Callable],
                 variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None):
        """Constructs an instance to compute the gradient of f w.r.t. variables.

        Args:
            f: the function for which to compute gradients.
            variables: the variables for which to compute gradients.
            input_argnums: input indexes, if any, on which to compute gradients.
        """
        self.f = f
        self.vc = variables = VarCollection(variables or ())
        if not isinstance(f, Module):
            f = Function(f, self.vc)

        assert isinstance(input_argnums, tuple) or input_argnums is None, \
            f'Must pass a tuple of indices to input_argnums; received {input_argnums}.'
        self.input_argnums = input_argnums or tuple()

        def f_func(inputs_and_train_tensors: List[JaxArray],
                   state_tensors: List[JaxArray], list_args: List,
                   kwargs: Dict):
            # Leading entries are the differentiated inputs; the remainder feed the TrainVars.
            num_inputs = len(self.input_argnums)
            inputs = inputs_and_train_tensors[:num_inputs]
            train_tensors = inputs_and_train_tensors[num_inputs:]
            saved_tensors = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                for position, value in zip(self.input_argnums, inputs):
                    list_args[position] = value
                outputs = f(*list_args, **kwargs)
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                return outputs[0], (outputs, variables.tensors())
            finally:
                # Put the original values back whatever happened inside f.
                self.vc.assign(saved_tensors)

        # Advertise f's signature, with the return annotation widened to include the gradients.
        signature = inspect.signature(f)
        self.__wrapped__ = f
        self.__signature__ = signature.replace(
            return_annotation=Tuple[List[JaxArray], signature.return_annotation])
        self._call = jax.grad(f_func, has_aux=True)
Ejemplo n.º 18
0
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Set up a LARS optimizer over the trainable variables of ``vc``.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum, self.weight_decay = momentum, weight_decay
        self.tc, self.eps = tc, eps
        trainable = vc.subset(TrainVar)
        self.train_vars = ModuleList(TrainRef(var) for var in trainable)
        # One zero-initialised moving-average buffer per trainable variable, shaped like its value.
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
Ejemplo n.º 19
0
def save_var_collection(file: Union[str, IO[BinaryIO]], vc: VarCollection):
    """Saves variables collection into file.

    When ``file`` is a filename, data is written to ``file + '.tmp'`` and only moved onto
    ``file`` after a successful write, so an interrupted save never leaves a truncated file.
    Each underlying variable is stored once, under the first name it appears with. A warning
    is printed when some values are ShardedDeviceArrays (replicated across devices).

    Args:
        file: filename or python file handle of the file where variables will be saved.
        vc: variables collection which will be saved into file.
    """
    do_close = isinstance(file, str)
    if do_close:
        # Save to a temporary in case the job is killed while saving.
        filename, file = file, open(file + '.tmp', 'wb')
    try:
        data, names, seen, replicated = {}, [], set(), []
        for k, v in vc.items():
            if isinstance(v, TrainRef):
                v = v.ref
            # Deduplicate by identity: the same variable may appear under several names.
            if id(v) not in seen:
                names.append(k)
                data[str(len(data))] = v.value
                seen.add(id(v))
            if isinstance(v.value, ShardedDeviceArray):
                replicated.append(k)
        if replicated:
            print(
                'Warning: When saving VarCollection, some variables were replicated on multiple devices.'
            )
            print(
                '         While it is valid, in most use cases it is more disk efficient to save variables outside of '
            )
            print('         vars().replicate().')
        np.savez(file, names=np.array(names), **data)
    finally:
        if do_close:
            file.close()
    if do_close:
        # os.replace (unlike os.rename) also overwrites an existing destination on Windows,
        # and is atomic when both paths are on the same filesystem.
        os.replace(filename + '.tmp', filename)
Ejemplo n.º 20
0
 def __init__(self, closure_vars, global_vars):
     """Keep the name lookup tables and start with an empty VarCollection.

     Args:
         closure_vars: lookup of closure variables by name (presumably a dict; confirm with caller).
         global_vars: lookup of global variables by name (presumably a dict; confirm with caller).
     """
     self.closure_vars, self.global_vars = closure_vars, global_vars
     self.vc = VarCollection()
Ejemplo n.º 21
0
 def vars(self, scope: str = '') -> VarCollection:
     """Return a new VarCollection of the variables used, names prefixed by ``scope`` if given."""
     if not scope:
         return VarCollection(self.vc)
     return VarCollection((scope + key, var) for key, var in self.vc.items())
Ejemplo n.º 22
0
 def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
     """Wrap an optimizer built from ``base_optimizer(vc, **kwargs)``.

     Args:
         vc: collection of variables; its TrainVar subset is tracked by this wrapper.
         base_optimizer: optimizer factory, called with ``vc`` and the extra keyword arguments.
         **kwargs: forwarded unchanged to ``base_optimizer``.
     """
     trainable = vc.subset(TrainVar)
     self.train_vars = ModuleList(TrainRef(var) for var in trainable)
     self.base_optimizer = base_optimizer(vc, **kwargs)
     # Per-key auxiliary state; each missing entry starts as an empty dict.
     self.state = defaultdict(dict)
Ejemplo n.º 23
0
 def __init__(self, vc: VarCollection):
     """Store ``vc`` with every variable name prefixed by ``(<ClassName>)``."""
     super().__init__()
     prefix = f'({self.__class__.__name__})'
     self.vc = VarCollection((f'{prefix}{key}', var) for key, var in vc.items())