def _custom_getter(
    getter=None,
    name=None,
    shape=None,
    dtype=dtypes.float32,
    initializer=None,
    regularizer=None,
    reuse=None,
    trainable=True,
    collections=None,
    caching_device=None,  # pylint: disable=redefined-outer-name
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    aggregation=None,
    synchronization=None):
  """Variable getter that creates a captured variable and initializes it now.

  If `name` is already a key of `self.tf_variables`, the existing entry's
  `initialized_value()` is returned when `reuse` is truthy; otherwise a
  `ValueError` is raised. Otherwise a `_CapturedVariable` is built, stored in
  `self.variables[name]`, and its resource handle is assigned the value of
  `initializer(shape, dtype)` (or of `_default_initializer(name, shape, dtype)`
  when no initializer was given).

  NOTE(review): `self` is a free variable resolved from the enclosing scope;
  presumably the object that owns `variables` / `tf_variables` -- confirm
  against the defining method.

  Args:
    getter, regularizer, collections, caching_device, partitioner,
    validate_shape, use_resource, aggregation, synchronization: accepted only
      so that callers passing them do not fail; they are discarded.
    name: Variable name; key into `self.tf_variables` and `self.variables`.
    shape: Forwarded to `_CapturedVariable` and to the initializer.
    dtype: Forwarded to `_CapturedVariable` and to the initializer.
    initializer: Callable invoked as `initializer(shape, dtype)`.
    reuse: When truthy, an existing entry for `name` may be returned.
    trainable: Forwarded to `_CapturedVariable`.

  Returns:
    The new captured variable (`v.variable`), or the existing entry's
    `initialized_value()` when reusing.

  Raises:
    ValueError: If `name` already exists in `self.tf_variables` and `reuse`
      is falsy.
  """
  # `aggregation` and `synchronization` were added so callers passing these
  # keywords don't hit TypeError; they default to None (rather than enum
  # values) because they are discarded immediately.
  del getter, regularizer, collections, caching_device, partitioner
  del use_resource, validate_shape, aggregation, synchronization
  if name in self.tf_variables:
    if not reuse:
      raise ValueError(
          "Specified reuse=%s but tried to reuse variables." % reuse)
    return self.tf_variables[name].initialized_value()
  # TODO(apassos): ensure this is on the same device as above
  v = _CapturedVariable(name, initializer, shape, dtype, trainable)
  self.variables[name] = v

  graph_mode_resource = v.variable.handle
  if initializer is None:
    initializer = _default_initializer(name, shape, dtype)
  # Write the initial value straight into the variable's resource handle.
  resource_variable_ops.shape_safe_assign_variable_handle(
      graph_mode_resource, v.variable.shape, initializer(shape, dtype))
  return v.variable
def _custom_getter(
    getter=None,
    name=None,
    shape=None,
    dtype=dtypes.float32,
    initializer=None,
    regularizer=None,
    reuse=None,
    trainable=None,
    collections=None,
    caching_device=None,  # pylint: disable=redefined-outer-name
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    aggregation=variable_scope.VariableAggregation.NONE,
    synchronization=variable_scope.VariableSynchronization.AUTO):
  """Getter that builds a captured variable and assigns its initial value.

  Reuse path: when `name` is already in `self.tf_variables`, a truthy `reuse`
  yields that entry's `initialized_value()`, and a falsy one raises
  `ValueError`.

  Creation path: a `_CapturedVariable(name, initializer, shape, dtype,
  trainable)` is registered under `self.variables[name]`. Its handle then
  receives `init_fn(shape, dtype)`, where `init_fn` is `initializer` or, if
  that is None, `_default_initializer(name, shape, dtype)`.

  NOTE(review): `self` comes from the enclosing scope (presumably the object
  owning `variables` / `tf_variables`); confirm against the defining method.

  Returns:
    The captured variable, or the reused entry's initialized value.

  Raises:
    ValueError: On an existing `name` with falsy `reuse`.
  """
  # Only name/shape/dtype/initializer/reuse/trainable are used; the rest are
  # accepted purely for getter-signature compatibility.
  del (getter, regularizer, collections, caching_device, partitioner,
       use_resource, validate_shape, aggregation, synchronization)

  if name in self.tf_variables:
    if not reuse:
      raise ValueError("Specified reuse=%s but tried to reuse variables." % reuse)
    return self.tf_variables[name].initialized_value()

  # TODO(apassos): ensure this is on the same device as above
  captured = _CapturedVariable(name, initializer, shape, dtype, trainable)
  self.variables[name] = captured

  resource_handle = captured.variable.handle
  init_fn = (initializer if initializer is not None
             else _default_initializer(name, shape, dtype))
  resource_variable_ops.shape_safe_assign_variable_handle(
      resource_handle, captured.variable.shape, init_fn(shape, dtype))
  return captured.variable
def restore(self, restored_tensors, restored_shapes):
  """Assigns the first restored tensor to this object's resource variable.

  Args:
    restored_tensors: Sequence of restored tensors; only element 0 is used.
    restored_shapes: Optional sequence of shapes. When not None, the tensor is
      reshaped to `restored_shapes[0]` before assignment.

  Returns:
    The result of `resource_variable_ops.shape_safe_assign_variable_handle`.

  Raises:
    ValueError: If the assignment rejects the tensor (e.g. its shape is
      incompatible with the variable's). The error is re-raised with the
      variable name and both shapes, chained to the original error.
  """
  restored_tensor = restored_tensors[0]
  if restored_shapes is not None:
    restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
  # Copy the restored tensor to the variable's device.
  with ops.device(self._var_device):
    restored_tensor = array_ops.identity(restored_tensor)
  try:
    assigned_variable = (
        resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, restored_tensor))
  except ValueError as e:
    # Without this context the caller cannot tell which checkpoint entry
    # failed to load; keep the original error as the cause.
    raise ValueError(
        f"Received incompatible tensor with shape {restored_tensor.shape} "
        f"when attempting to restore variable with shape {self._var_shape} "
        f"and name {self.name}.") from e
  return assigned_variable
def restore(self, restored_tensors, restored_shapes):
  """Assigns the first restored tensor to this object's resource variable.

  Args:
    restored_tensors: Sequence of restored tensors; only element 0 is used.
    restored_shapes: Optional sequence of shapes. When not None, the tensor is
      reshaped to `restored_shapes[0]` before assignment.

  Returns:
    The result of `resource_variable_ops.shape_safe_assign_variable_handle`.

  Raises:
    ValueError: If the assignment rejects the tensor (e.g. its shape is
      incompatible with the variable's). The error is re-raised with the
      variable name and both shapes, chained to the original error.
  """
  restored_tensor = restored_tensors[0]
  if restored_shapes is not None:
    restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
  # Copy the restored tensor to the variable's device.
  with ops.device(self._var_device):
    restored_tensor = array_ops.identity(restored_tensor)
  try:
    assigned_variable = (
        resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, restored_tensor))
  except ValueError as e:
    # Without this context the caller cannot tell which checkpoint entry
    # failed to load; keep the original error as the cause.
    raise ValueError(
        f"Received incompatible tensor with shape {restored_tensor.shape} "
        f"when attempting to restore variable with shape {self._var_shape} "
        f"and name {self.name}.") from e
  return assigned_variable
def restore(self, restored_tensors, restored_shapes):
  """Loads `restored_tensors[0]` into the wrapped resource variable.

  The tensor is optionally reshaped to `restored_shapes[0]`, copied onto the
  variable's device, and then assigned via
  `resource_variable_ops.shape_safe_assign_variable_handle`.

  Returns:
    Whatever the shape-safe assignment returns.

  Raises:
    ValueError: If the restored tensor has an incompatible shape; the message
      names the variable and both shapes, and the original error is chained.
  """
  first_tensor = restored_tensors[0]
  restored_tensor = (first_tensor if restored_shapes is None
                     else array_ops.reshape(first_tensor, restored_shapes[0]))

  # Materialize a copy on the same device as the variable.
  with ops.device(self._var_device):
    restored_tensor = array_ops.identity(restored_tensor)

  assign_fn = resource_variable_ops.shape_safe_assign_variable_handle
  try:
    assign_result = assign_fn(self.handle_op, self._var_shape, restored_tensor)
  except ValueError as err:
    raise ValueError(
        f"Received incompatible tensor with shape {restored_tensor.shape} "
        f"when attempting to restore variable with shape {self._var_shape} "
        f"and name {self.name}.") from err
  else:
    return assign_result