def __exit__(self, type, value, traceback):
    """Leave this variable scope and undo what ``__enter__`` pushed.

    Reinstates the scope saved in ``self._old_varscope`` as the module-level
    ``_VARSCOPE``. When this scope is named, the ``'<name>/'`` suffix is
    sliced back off the end of the current tensor scope.

    The exception arguments are not inspected, and the implicit ``None``
    return means exceptions raised in the ``with`` body propagate.
    """
    global _VARSCOPE
    _VARSCOPE = self._old_varscope
    from dragon.core.scope import get_tensor_scope, set_tensor_scope
    is_named = self._name_scope != ''
    suffix = self._name_scope + '/' if is_named else ''
    # Invariant check: the tensor scope should still end with our suffix;
    # a mismatch suggests it was modified and not restored while this
    # scope was active.
    assert get_tensor_scope().endswith(suffix)
    if is_named:
        set_tensor_scope(get_tensor_scope()[:-len(suffix)])
def add_variable(self, name, shape, dtype=None, trainable=True,
                 initializer=None, regularizer=None):
    """Create (or fetch, when reusing) a weight variable for this layer.

    The variable is obtained via ``vs.get_variable`` inside this layer's
    variable scope. Reuse is enabled once the layer is built, or when
    ``self._reuse`` is set.

    Args:
        name: Variable name, relative to the layer's scope.
        shape: Shape of the variable.
        dtype: Data type; defaults to ``self.dtype`` when ``None``.
        trainable: Whether the variable is trainable. The flag passed on is
            ``trainable and self.trainable``.
        initializer: Initializer forwarded to ``vs.get_variable``.
        regularizer: Not supported. A truthy value raises
            ``NotImplementedError`` when the variable was newly created.

    Returns:
        The variable. If it already existed before this call, it is returned
        as-is and is not appended to the layer's weight lists again.

    Raises:
        NotImplementedError: If ``regularizer`` is given and the variable was
            newly created by this call.
    """
    if dtype is None:
        dtype = self.dtype
    # Snapshot the global variables so we can tell whether get_variable
    # created a new variable or handed back an existing one (reuse path).
    existing_variables = set(tf_variables.global_variables())
    with vs.variable_scope(self._scope,
                           reuse=self.built or self._reuse) as scope:
        with ops.name_scope(scope.original_name_scope):
            variable = vs.get_variable(
                name,
                shape=shape,
                initializer=initializer,
                dtype=dtypes.as_dtype(dtype),
                trainable=trainable and self.trainable)
            if variable in existing_variables:
                # Variable pre-existed (reuse path): it is not appended to
                # the weight lists again.
                return variable
            if regularizer:
                raise NotImplementedError()
    if trainable:
        self._trainable_weights.append(variable)
    else:
        self._non_trainable_weights.append(variable)
    return variable
def __enter__(self):
    """Activate this variable scope.

    Remembers the currently active scope in ``self._old_varscope`` so that
    ``__exit__`` can restore it. Installs ``self`` as the module-level
    ``_VARSCOPE``. When this scope is named, appends ``'<name>/'`` to the
    current tensor scope.

    Returns:
        This scope object, so it can be bound with ``with ... as``.
    """
    global _VARSCOPE
    self._old_varscope, _VARSCOPE = _VARSCOPE, self
    from dragon.core.scope import get_tensor_scope, set_tensor_scope
    suffix = '' if self._name_scope == '' else self._name_scope + '/'
    set_tensor_scope(get_tensor_scope() + suffix)
    return self
def get_variable(self, name, shape=None, dtype=None, initializer=None,
                 trainable=True, collections=None, validate_shape=True,
                 **kwargs):
    """Create a new variable, or return an existing one when reusing.

    Variables are stored in the module-level ``_VARSTORE`` dict. They are
    keyed by their full name: the current tensor scope followed by ``name``.

    Args:
        name: Variable name relative to the current tensor scope.
        shape: Shape of the variable; required when creating a new one.
        dtype: Data type, forwarded to the initializer and to ``Variable``.
        initializer: Called as ``initializer(shape, dtype=dtype)``. When
            ``None``, ``self._get_default_initializer`` supplies one.
        trainable: Forwarded to ``Variable``.
        collections: Forwarded to ``Variable``.
        validate_shape: Forwarded to ``Variable``.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        The newly created variable. If the variable already exists and
        ``self._reuse`` is truthy, the stored variable is returned instead.

    Raises:
        ValueError: If a new variable is requested without a ``shape``, or
            if the variable already exists and reuse is disabled.
    """
    from dragon.core.scope import get_tensor_scope
    full_name = get_tensor_scope() + name

    if full_name in _VARSTORE:
        if self._reuse:
            return _VARSTORE[full_name]
        raise ValueError(
            'The Variable({}) already exists.'.format(full_name))

    if shape is None:
        raise ValueError(
            'Must specify a shape for the Variable({}).'.format(full_name))
    if initializer is None:
        initializer = self._get_default_initializer(
            name, shape=shape, dtype=dtype)
    initial_value = initializer(shape, dtype=dtype)
    new_var = Variable(initial_value, trainable=trainable,
                       collections=collections,
                       validate_shape=validate_shape, name=name, dtype=dtype)
    # Subscript assignment mutates the module-level dict; no `global` needed.
    _VARSTORE[full_name] = new_var
    return new_var
def name_scope(self, remove_separator=True):
    """Return the current tensor scope string.

    Args:
        remove_separator: If True, strip a single trailing ``'/'`` from the
            scope.

    Returns:
        The value of ``get_tensor_scope()``, with its trailing separator
        removed when requested. An empty (root) scope is returned as ``''``;
        the previous ``scope[-1]`` check raised ``IndexError`` in that case.
    """
    scope = get_tensor_scope()
    if remove_separator and scope.endswith('/'):
        scope = scope[:-1]
    return scope