Code example #1
        def _custom_getter(
                getter=None,
                name=None,
                shape=None,
                dtype=dtypes.float32,  # pylint: disable=missing-docstring
                initializer=None,
                regularizer=None,
                reuse=None,
                trainable=True,
                collections=None,
                caching_device=None,  # pylint: disable=redefined-outer-name
                partitioner=None,
                validate_shape=True,
                use_resource=None):
            # These arguments are accepted for variable_scope compatibility but
            # are not used by the capturing machinery.
            del getter, regularizer, collections, caching_device, partitioner
            del use_resource, validate_shape
            if name in self.tf_variables:
                # A graph variable with this name already exists: hand back its
                # value only when the caller explicitly asked for reuse.
                if not reuse:
                    raise ValueError(
                        "Specified reuse=%s but tried to reuse variables." %
                        reuse)
                return self.tf_variables[name].initialized_value()
            # TODO(apassos): ensure this is on the same device as above
            captured = _CapturedVariable(name, initializer, shape, dtype,
                                         trainable)
            self.variables[name] = captured

            # Fall back to a default initializer only after the capture was
            # recorded with the caller-supplied (possibly None) initializer.
            if initializer is None:
                initializer = _default_initializer(name, shape, dtype)
            resource_variable_ops.shape_safe_assign_variable_handle(
                captured.variable.handle, captured.variable.shape,
                initializer(shape, dtype))
            return captured.variable
Code example #2
    def _custom_getter(  # pylint: disable=missing-docstring
        getter=None,
        name=None,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,  # pylint: disable=redefined-outer-name
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        aggregation=variable_scope.VariableAggregation.NONE,
        synchronization=variable_scope.VariableSynchronization.AUTO):
      # Accepted for signature compatibility with variable_scope getters;
      # unused here.
      del getter, regularizer, collections, caching_device, partitioner
      del use_resource, validate_shape, aggregation, synchronization
      if name in self.tf_variables:
        # Existing variable: only legal when reuse was requested.
        if not reuse:
          raise ValueError("Specified reuse=%s but tried to reuse variables."
                           % reuse)
        return self.tf_variables[name].initialized_value()
      # TODO(apassos): ensure this is on the same device as above
      captured = _CapturedVariable(name, initializer, shape, dtype, trainable)
      self.variables[name] = captured

      # Default the initializer only after capturing with the original value.
      if initializer is None:
        initializer = _default_initializer(name, shape, dtype)
      resource_variable_ops.shape_safe_assign_variable_handle(
          captured.variable.handle, captured.variable.shape,
          initializer(shape, dtype))
      return captured.variable
Code example #3
 def restore(self, restored_tensors, restored_shapes):
   """Assigns the first restored tensor to the variable's handle.

   The tensor is reshaped to the saved shape (when one is provided) and
   copied onto the variable's device before the shape-safe assignment.
   """
   tensor = restored_tensors[0]
   if restored_shapes is not None:
     tensor = array_ops.reshape(tensor, restored_shapes[0])
   # Pin the copy and the assignment to the variable's device.
   with ops.device(self._var_device):
     return resource_variable_ops.shape_safe_assign_variable_handle(
         self.handle_op, self._var_shape, array_ops.identity(tensor))
Code example #4
 def restore(self, restored_tensors, restored_shapes):
   """Writes restored_tensors[0] into the variable, reshaping if needed."""
   value = restored_tensors[0]
   if restored_shapes is not None:
     value = array_ops.reshape(value, restored_shapes[0])
   with ops.device(self._var_device):
     # Materialize a copy on the variable's device, then assign it through
     # the shape-checked handle op.
     value = array_ops.identity(value)
     return resource_variable_ops.shape_safe_assign_variable_handle(
         self.handle_op, self._var_shape, value)
Code example #5
 def restore(self, restored_tensors, restored_shapes):
   """Restores tensors. Raises ValueError if incompatible shape found."""
   tensor = restored_tensors[0]
   if restored_shapes is not None:
     tensor = array_ops.reshape(tensor, restored_shapes[0])
   # Copy the restored tensor to the variable's device.
   with ops.device(self._var_device):
     tensor = array_ops.identity(tensor)
     try:
       return resource_variable_ops.shape_safe_assign_variable_handle(
           self.handle_op, self._var_shape, tensor)
     except ValueError as e:
       # Re-raise with the variable's name and both shapes for diagnosis.
       raise ValueError(
           f"Received incompatible tensor with shape {tensor.shape} "
           f"when attempting to restore variable with shape {self._var_shape} "
           f"and name {self.name}.") from e