Code example #1
0
    def CopyBaseParams(from_params, to_params):
        """Copies BaseLayer params from `from_params` to `to_params`."""
        assert issubclass(from_params.cls, BaseLayer)
        assert issubclass(to_params.cls, BaseLayer)
        # dtype is inherited only while the target is still at its tf.float32
        # default; an explicitly overridden dtype on the target wins.
        if to_params.dtype == tf.float32:
            to_params.dtype = from_params.dtype
        # fprop_dtype, by contrast, propagates whenever the source sets it.
        if from_params.fprop_dtype is not None:
            to_params.fprop_dtype = from_params.fprop_dtype
        # These scalar settings inherit from the source only when the target
        # has not set them (None means "unset").
        for attr in ('random_seed', 'is_inference', 'allow_implicit_capture',
                     'skip_lp_regularization'):
            if getattr(to_params, attr) is None:
                setattr(to_params, attr, getattr(from_params, attr))

        # Only copy from base when vn config is using the default setting.
        if to_params.vn == py_utils.DefaultVN():
            to_params.vn = from_params.vn.Copy()

        # TODO(rpang): derive to_params.params_init.seed from
        # from_params.params_init.seed if it is specified in 'from_params' and not
        # in 'to_params'.
        if py_utils.IsDefaultParamInit(to_params.params_init):
            # Copy over params_init as well.
            to_params.params_init = from_params.params_init.Copy()
        return to_params
Code example #2
0
 def Params(cls: Type[BaseLayerT]) -> BaseLayerParamsT:
     """Returns the layer params."""
     p = hyperparams.InstantiableParams(cls)
     # Table of (name, default, description). The tuple order is significant:
     # it fixes the order in which the hyperparameters are Define()d.
     base_param_specs = (
         ('inference_driver_name', cls._INFERENCE_DRIVER_NAME,
          'Name of the inference driver used to construct this layer.'),
         ('name', '', 'Name of this layer object.'),
         ('dtype', tf.float32, 'Datatype to use.'),
         # None value will make FProp use dtype instead of fprop_dtype.
         # TODO(lepikhin): all @tf.Defun should use p.fprop_dtype if it is set.
         ('fprop_dtype', None, 'Activations datatype to use.'),
         ('random_seed', None,
          'Random seed for deterministic unittests. This '
          'is inherited by child layers if they do not set a random_seed.'),
         ('vn', py_utils.DefaultVN(),
          'How variational noise should be applied.'),
         ('params_init', py_utils.DefaultParamInit(),
          'How model weights should be initialized. Not to be confused with '
          'hyperparams.'),
         ('add_name_to_theta', False,
          'Wrap theta with tf.identity(var_name).'),
         # Makes additional alterations for graphs being used for inference.
         ('is_inference', None, 'True if in inference mode.'),
         # In addition to is_inference, indicate that the inference graph is
         # for a single step.
         ('allow_implicit_capture', None,
          'When using Defuns, code often asserts that the Defun does not '
          'capture undeclared inputs. This eliminates a source of bugs '
          'at the expense of making some kinds of models or utilities '
          'hard/impossible to use. Setting this to True/False (versus None) '
          'causes the setting to apply to this layer and its children.'),
         ('skip_lp_regularization', None,
          'If True, all variables in this layer will skip Lp regularization. '
          'If None/False, only variables explicitly in the '
          'SKIP_LP_REGULARIZATION collection will skip Lp regularization. '
          'Also propagated to child layers with default settings (None).'),
         # SPMD partition related params.
         ('device_mesh', None,
          'A numpy.ndarray specifying the topology of a device mesh to place the'
          ' computations onto. If device_mesh is None, it is assumed to be a'
          ' single device. Here are some examples:'
          ' np.array([0, 1, 2, 3, 4, 5, 6, 7]) which is a 1d mesh with 8 devices,'
          ' np.array([[0, 1, 2, 3], [4, 5, 6, 7]]) which is 2d matrix of 8'
          ' devices.'),
         ('weight_split_dims_mapping', None,
          'Relevant only if device_mesh above is not None. If not None, it '
          'specifies how weight of this layer or those of the sublayers should '
          'be sharded over device mesh. '),
         ('activation_split_dims_mapping', None,
          'Relevant only if device_mesh above is not None. If not None, it '
          'specifies how activation of this layer or those of the sublayers '
          'should be sharded over device mesh. '),
     )
     for param_name, default_value, description in base_param_specs:
         p.Define(param_name, default_value, description)
     return p
Code example #3
0
 def Params(cls):
     """Returns the layer params."""
     p = hyperparams.InstantiableParams(cls)
     # Local alias: every hyperparameter below is registered through it.
     define = p.Define
     define('inference_driver_name', cls._INFERENCE_DRIVER_NAME,
            'Name of the inference driver used to construct this layer.')
     define('name', '', 'Name of this layer object.')
     define('dtype', tf.float32, 'Datatype to use.')
     # None value will make FProp use dtype instead of fprop_dtype.
     # TODO(lepikhin): all @tf.Defun should use p.fprop_dtype if it is set.
     define('fprop_dtype', None, 'Activations datatype to use.')
     define('random_seed', None,
            'Random seed for deterministic unittests. This '
            'is inherited by child layers if they do not set a random_seed.')
     define('vn', py_utils.DefaultVN(),
            'How variational noise should be applied.')
     define('params_init', py_utils.DefaultParamInit(),
            'How model weights should be initialized. Not to be confused with '
            'hyperparams.')
     # Makes additional alterations for graphs being used for inference.
     define('is_inference', None, 'True if in inference mode.')
     # In addition to is_inference, indicate that the inference graph is
     # for a single step.
     define('allow_implicit_capture', None,
            'When using Defuns, code often asserts that the Defun does not '
            'capture undeclared inputs. This eliminates a source of bugs '
            'at the expense of making some kinds of models or utilities '
            'hard/impossible to use. Setting this to True/False (versus None) '
            'causes the setting to apply to this layer and its children.')
     define('skip_lp_regularization', None,
            'If True, all variables in this layer will skip Lp regularization. '
            'If None/False, only variables explicitly in the '
            'SKIP_LP_REGULARIZATION collection will skip Lp regularization. '
            'Also propagated to child layers with default settings (None).')
     return p