예제 #1
0
파일: base_layer.py 프로젝트: zge/lingvo
    def CopyBaseParams(from_params, to_params):
        """Propagates BaseLayer hyperparameters from a parent to a child.

        Fields on `to_params` that still hold their "unset" marker value are
        filled in from `from_params`; fields the child has already customized
        are left untouched. `fprop_dtype` is the exception: whenever the
        parent sets it (non-None), the parent's value is written to the child.

        Args:
            from_params: Params whose `cls` is a BaseLayer subclass; the
                source of inherited values. Not modified.
            to_params: Params whose `cls` is a BaseLayer subclass; updated
                in place.

        Returns:
            The same `to_params` object, after modification.
        """
        for params in (from_params, to_params):
            assert issubclass(params.cls, BaseLayer)

        # dtype is inherited only while the child still has tf.float32
        # (presumably the field's default value -- confirm in Params()).
        if to_params.dtype == tf.float32:
            to_params.dtype = from_params.dtype

        # Unlike every other field here, fprop_dtype is parent-driven: an
        # explicit parent value overrides whatever the child holds.
        parent_fprop_dtype = from_params.fprop_dtype
        if parent_fprop_dtype is not None:
            to_params.fprop_dtype = parent_fprop_dtype

        # For these fields a None on the child means "not set"; fill from the
        # parent. The tuple order matches the original attribute-access order.
        for field_name in ('random_seed', 'is_eval', 'is_inference',
                           'allow_implicit_capture', 'skip_lp_regularization'):
            if getattr(to_params, field_name) is None:
                setattr(to_params, field_name, getattr(from_params, field_name))

        # The vn config is replaced only if the child still uses the default
        # one; the child receives its own copy, not a shared reference.
        if to_params.vn == DefaultVN():
            to_params.vn = from_params.vn.Copy()

        # TODO(rpang): derive to_params.params_init.seed from
        # from_params.params_init.seed if it is specified in 'from_params' and not
        # in 'to_params'.
        # A default initializer on the child is replaced by a copy of the
        # parent's initializer.
        if py_utils.IsDefaultParamInit(to_params.params_init):
            to_params.params_init = from_params.params_init.Copy()
        return to_params
예제 #2
0
 def testIsDefaultParamInit(self):
   """DefaultParamInit() must be recognized by IsDefaultParamInit()."""
   self.assertTrue(
       py_utils.IsDefaultParamInit(py_utils.DefaultParamInit()))