def CopyBaseParams(from_params, to_params):
  """Copies BaseLayer params from `from_params` to `to_params`.

  Only fields still at their "unset" default in `to_params` are inherited,
  so explicit settings on `to_params` always win.

  Args:
    from_params: Source params; `from_params.cls` must subclass BaseLayer.
    to_params: Destination params; `to_params.cls` must subclass BaseLayer.
      Mutated in place.

  Returns:
    `to_params`, for chaining.
  """
  assert issubclass(from_params.cls, BaseLayer)
  assert issubclass(to_params.cls, BaseLayer)

  # dtype defaults to tf.float32, so that value is treated as "unset".
  if to_params.dtype == tf.float32:
    to_params.dtype = from_params.dtype
  # fprop_dtype propagates whenever the source has one set, unconditionally
  # overwriting the destination.
  if from_params.fprop_dtype is not None:
    to_params.fprop_dtype = from_params.fprop_dtype

  # These fields all use None as their "unset" sentinel; inherit when unset.
  for field in ('random_seed', 'is_inference', 'allow_implicit_capture',
                'skip_lp_regularization'):
    if getattr(to_params, field) is None:
      setattr(to_params, field, getattr(from_params, field))

  # Only copy from base when vn config is using the default setting.
  if to_params.vn == py_utils.DefaultVN():
    to_params.vn = from_params.vn.Copy()

  # TODO(rpang): derive to_params.params_init.seed from
  # from_params.params_init.seed if it is specified in 'from_params' and not
  # in 'to_params'.
  if py_utils.IsDefaultParamInit(to_params.params_init):
    # Copy over params_init as well.
    to_params.params_init = from_params.params_init.Copy()
  return to_params
def Params(cls: Type[BaseLayerT]) -> BaseLayerParamsT:
  """Returns the layer params."""
  p = hyperparams.InstantiableParams(cls)
  # (name, default, description) for every base-layer hyperparameter, in
  # definition order.
  base_defs = [
      ('inference_driver_name', cls._INFERENCE_DRIVER_NAME,
       'Name of the inference driver used to construct this layer.'),
      ('name', '', 'Name of this layer object.'),
      ('dtype', tf.float32, 'Datatype to use.'),
      # None value will make FProp use dtype instead of fprop_dtype.
      # TODO(lepikhin): all @tf.Defun should use p.fprop_dtype if it is set.
      ('fprop_dtype', None, 'Activations datatype to use.'),
      ('random_seed', None,
       'Random seed for deterministic unittests. This '
       'is inherited by child layers if they do not set a random_seed.'),
      ('vn', py_utils.DefaultVN(),
       'How variational noise should be applied.'),
      ('params_init', py_utils.DefaultParamInit(),
       'How model weights should be initialized. Not to be confused with '
       'hyperparams.'),
      ('add_name_to_theta', False, 'Wrap theta with tf.identity(var_name).'),
      # Makes additional alterations for graphs being used for inference.
      ('is_inference', None, 'True if in inference mode.'),
      # In addition to is_inference, indicate that the inference graph is
      # for a single step.
      ('allow_implicit_capture', None,
       'When using Defuns, code often asserts that the Defun does not '
       'capture undeclared inputs. This eliminates a source of bugs '
       'at the expense of making some kinds of models or utilities '
       'hard/impossible to use. Setting this to True/False (versus None) '
       'causes the setting to apply to this layer and its children.'),
      ('skip_lp_regularization', None,
       'If True, all variables in this layer will skip Lp regularization. '
       'If None/False, only variables explicitly in the '
       'SKIP_LP_REGULARIZATION collection will skip Lp regularization. '
       'Also propagated to child layers with default settings (None).'),
      # SPMD partition related params.
      ('device_mesh', None,
       'A numpy.ndarray specifying the topology of a device mesh to place the'
       ' computations onto. If device_mesh is None, it is assumed to be a'
       ' single device. Here are some examples:'
       ' np.array([0, 1, 2, 3, 4, 5, 6, 7]) which is a 1d mesh with 8 devices,'
       ' np.array([[0, 1, 2, 3], [4, 5, 6, 7]]) which is 2d matrix of 8'
       ' devices.'),
      ('weight_split_dims_mapping', None,
       'Relevant only if device_mesh above is not None. If not None, it '
       'specifies how weight of this layer or those of the sublayers should '
       'be sharded over device mesh. '),
      ('activation_split_dims_mapping', None,
       'Relevant only if device_mesh above is not None. If not None, it '
       'specifies how activation of this layer or those of the sublayers '
       'should be sharded over device mesh. '),
  ]
  for field_name, default_value, description in base_defs:
    p.Define(field_name, default_value, description)
  return p
def Params(cls):
  """Returns the layer params."""
  p = hyperparams.InstantiableParams(cls)
  # Declare each hyperparameter as (name, default, description), in order.
  specs = (
      ('inference_driver_name', cls._INFERENCE_DRIVER_NAME,
       'Name of the inference driver used to construct this layer.'),
      ('name', '', 'Name of this layer object.'),
      ('dtype', tf.float32, 'Datatype to use.'),
      # None value will make FProp use dtype instead of fprop_dtype.
      # TODO(lepikhin): all @tf.Defun should use p.fprop_dtype if it is set.
      ('fprop_dtype', None, 'Activations datatype to use.'),
      ('random_seed', None,
       'Random seed for deterministic unittests. This '
       'is inherited by child layers if they do not set a random_seed.'),
      ('vn', py_utils.DefaultVN(),
       'How variational noise should be applied.'),
      ('params_init', py_utils.DefaultParamInit(),
       'How model weights should be initialized. Not to be confused with '
       'hyperparams.'),
      # Makes additional alterations for graphs being used for inference.
      ('is_inference', None, 'True if in inference mode.'),
      # In addition to is_inference, indicate that the inference graph is
      # for a single step.
      ('allow_implicit_capture', None,
       'When using Defuns, code often asserts that the Defun does not '
       'capture undeclared inputs. This eliminates a source of bugs '
       'at the expense of making some kinds of models or utilities '
       'hard/impossible to use. Setting this to True/False (versus None) '
       'causes the setting to apply to this layer and its children.'),
      ('skip_lp_regularization', None,
       'If True, all variables in this layer will skip Lp regularization. '
       'If None/False, only variables explicitly in the '
       'SKIP_LP_REGULARIZATION collection will skip Lp regularization. '
       'Also propagated to child layers with default settings (None).'),
  )
  for key, default, doc in specs:
    p.Define(key, default, doc)
  return p