def CopyBaseParams(from_params, to_params):
  """Propagates BaseLayer-level settings from `from_params` into `to_params`.

  Most fields are inherited only while they still hold their unset/default
  value on `to_params`, so anything the caller already customized is kept:
    * `dtype` is inherited when `to_params.dtype` equals tf.float32.
    * `random_seed`, `is_eval`, `is_inference`, `allow_implicit_capture` and
      `skip_lp_regularization` are inherited when they are None on
      `to_params`.
    * `vn` is inherited (as a copy) when `to_params.vn == DefaultVN()`.
    * `params_init` is inherited (as a copy) when
      `py_utils.IsDefaultParamInit(to_params.params_init)` is true.
  The exception is `fprop_dtype`: whenever `from_params.fprop_dtype` is not
  None it overwrites `to_params.fprop_dtype`, even if the latter was set.

  Args:
    from_params: Source params; its `cls` must be a subclass of BaseLayer.
    to_params: Destination params; its `cls` must be a subclass of BaseLayer.
      Modified in place.

  Returns:
    `to_params` (the same object passed in), after updating.

  Raises:
    AssertionError: if either params' `cls` is not a BaseLayer subclass
      (only when assertions are enabled).
  """
  for layer_params in (from_params, to_params):
    assert issubclass(layer_params.cls, BaseLayer)

  src, dst = from_params, to_params

  # tf.float32 on the destination is treated as "not overridden".
  if dst.dtype == tf.float32:
    dst.dtype = src.dtype

  # Unlike the other fields, the source's fprop_dtype wins whenever it is set.
  if src.fprop_dtype is not None:
    dst.fprop_dtype = src.fprop_dtype

  # Scalar fields inherited only while still unset (None) on the destination.
  # Order matches the historical attribute access order.
  inherit_if_unset = (
      'random_seed',
      'is_eval',
      'is_inference',
      'allow_implicit_capture',
      'skip_lp_regularization',
  )
  for field_name in inherit_if_unset:
    if getattr(dst, field_name) is None:
      setattr(dst, field_name, getattr(src, field_name))

  # Only copy from base when the vn config is still the default setting.
  if dst.vn == DefaultVN():
    dst.vn = src.vn.Copy()

  # TODO(rpang): derive to_params.params_init.seed from
  # from_params.params_init.seed if it is specified in 'from_params' and not
  # in 'to_params'.
  if py_utils.IsDefaultParamInit(dst.params_init):
    # Copy over params_init as well.
    dst.params_init = src.params_init.Copy()

  return dst
def testIsDefaultParamInit(self):
  """IsDefaultParamInit must recognize the value DefaultParamInit returns."""
  default_init = py_utils.DefaultParamInit()
  recognized = py_utils.IsDefaultParamInit(default_init)
  self.assertTrue(recognized)