Example #1
 def test_full_optim_state_dict_keys(self):
     """Tests that the parameter keys returned by
     :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
     full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
     instances and ignored modules."""
     device = torch.device("cuda")
     model = NestedModel().to(device)
     wrapped_model = NestedModel.wrap(model, ignore_modules=True)
     # Add checkpointing to ensure optim_state_dict and state_dict strip out
     # checkpointing prefixes.
     apply_activation_checkpointing_wrapper(
         model,
         check_fn=lambda module: isinstance(module, torch.nn.Sequential))
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._step_model(model, optim, device)
     optim_state_dict = FSDP.full_optim_state_dict(wrapped_model,
                                                   optim,
                                                   rank0_only=False)
     with FSDP.state_dict_type(wrapped_model,
                               StateDictType.FULL_STATE_DICT):
         state_dict = wrapped_model.state_dict()
     self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
     # Check that checkpointing prefix was indeed stripped.
     for key in optim_state_dict["state"]:
         self.assertNotIn(_CHECKPOINT_PREFIX, key)
Example #2
 def _get_full_state_dict_mgr(self, model, state_dict_rank0_and_offload):
     return FSDP.state_dict_type(
         model, StateDictType.FULL_STATE_DICT,
         FullStateDictConfig(
             rank0_only=state_dict_rank0_and_offload,
             offload_to_cpu=state_dict_rank0_and_offload,
         ))
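This helper only builds the context manager; a call site is expected to enter it around state_dict(). A minimal sketch of that pattern, with the surrounding test method and variable names assumed:

    with self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload=True):
        # With rank0_only=True and offload_to_cpu=True, only rank 0 materializes
        # the full state dict, and its tensors are offloaded to CPU.
        full_sd = model.state_dict()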
Example #3
 def test_wrong_state_dict_config(self):
     model = FSDP(Model(wrap_fsdp=True).cuda())
     with self.assertRaisesRegex(RuntimeError,
                                 "Expected state_dict_config of type"):
         with model.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                                    LocalStateDictConfig()):
             pass
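The mismatch is deliberate: FULL_STATE_DICT expects a FullStateDictConfig, so passing a LocalStateDictConfig should trip the type check whose error message is matched above.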
Example #4
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example #5
 def test_state_dict_with_ignored_modules(self):
     # Initialize an FSDP-wrapped model with an ignored module that includes
     # both parameters and a buffer
     model = Model(wrap_fsdp=True, register_buffers=True).cuda()
     ignored_modules = [model.outer]
     ignored_tensor_to_tensor_name = {
         model.outer.bias: "outer.bias",
         model.outer.weight: "outer.weight",
         model.outer.buffer: "outer.buffer",
     }
     buffer_to_buffer_name = {
         model.inner.buffer: "inner.buffer",
         model.outer.buffer: "outer.buffer",
     }
     fsdp_model = FSDP(model, ignored_modules=ignored_modules)
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd1 = fsdp_model.state_dict()
     with FSDP.summon_full_params(fsdp_model):
         fsdp_params = deepcopy(list(fsdp_model.parameters()))
     # Check that the ignored parameters and all buffers are not cloned
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)
         self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
     # Check that the state dict can be loaded into a non-wrapped version of
     # the model
     nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
     for param in nonwrapped_model.parameters():
         with torch.no_grad():
             param.zero_()
     nonwrapped_model.load_state_dict(sd1)
     local_params = list(nonwrapped_model.parameters())
     for fsdp_param, local_param in zip(fsdp_params, local_params):
         self.assertEqual(fsdp_param, local_param)
     # Check that if we save a state dict again, the ignored parameters and
     # buffers still have the same data pointers
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd2 = fsdp_model.state_dict()
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)  # check again just in case
         self.assertTrue(tensor_name in sd2)
         self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
         self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())
Example #6
 def test_state_dict_type(self):
     module = SkipModel(double_nest=True)
     with enable_wrap(wrapper_cls=FSDP):
         fsdp = wrap(module)
     with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
         pass
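     # After the context exits, every FSDP submodule should revert to the
     # default FULL_STATE_DICT setting.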
     for module in FSDP.fsdp_modules(fsdp):
         self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
Example #7
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()
Example #8
    def _load_state_dict(model: Module, state_dict_type: str,
                         state_dict: Dict[str, Any]):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict for {state_dict_type}")

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)
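Both helpers above index into STATE_DICT_MAPPING, which is not shown in these snippets. Judging from the inline dict in Example #12 further down, it presumably maps the string names to the corresponding StateDictType values, roughly:

    STATE_DICT_MAPPING = {
        "state_dict": StateDictType.FULL_STATE_DICT,
        "local_state_dict": StateDictType.LOCAL_STATE_DICT,
        "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
    }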
Example #9
 def _get_state_dict_mgr(self, model, state_dict_type, state_dict_rank0_and_offload):
     _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
     if state_dict_type == "state_dict":
         config = FullStateDictConfig(
             rank0_only=state_dict_rank0_and_offload,
             offload_to_cpu=state_dict_rank0_and_offload,
         )
     else:
         config = None
     return FSDP.state_dict_type(model, _state_dict_type, config)
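Examples #14, #16, and #17 below show the intended call pattern: the returned manager is entered in a with block around model.state_dict() before the result is validated and loaded.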
Example #10
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            paths = [path]
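            # Broadcast rank 0's temporary directory so every rank uses the same checkpoint path.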
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
Example #11
 def test_full_optim_state_dict_keys(self):
     """Tests that the parameter keys returned by
     :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
     full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
     instances and ignored modules."""
     device = torch.device("cuda")
     model = NestedModel().to(device)
     wrapped_model = NestedModel.wrap(model, ignore_modules=True)
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._step_model(model, optim, device)
     optim_state_dict = FSDP.full_optim_state_dict(wrapped_model,
                                                   optim,
                                                   rank0_only=False)
     with FSDP.state_dict_type(wrapped_model,
                               StateDictType.FULL_STATE_DICT):
         state_dict = wrapped_model.state_dict()
     self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
Example #12
    def _dist_train(self,
                    wrap_fsdp: bool,
                    state_dict_type: str = "",
                    with_context: bool = False):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64,
                             4,
                             requires_grad=True,
                             device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            if with_context:
                state_dict_type = {
                    "state_dict": StateDictType.FULL_STATE_DICT,
                    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
                    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
                }[state_dict_type]
                with FSDP.state_dict_type(model, state_dict_type):
                    state_dict = model.state_dict()
                with FSDP.state_dict_type(blank_model, state_dict_type):
                    blank_model.load_state_dict(state_dict)
            else:
                state_dict = self._state_dict(model, state_dict_type)
                self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())
Example #13
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(wrap_fsdp=True,
                      register_buffers=True,
                      ignore_inner=ignore_inner).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored, this test also ensures
        # that non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict(prefix=prefix_str)
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(tensor.data_ptr(),
                             sd1[prefixed_tensor_name].data_ptr(),
                             f"{prefixed_tensor_name}")
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffers still have the same data pointers
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(sd1[prefixed_tensor_name].data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
Example #14
    def test_basic_save_and_load_state_dict(
        self, state_dict_type, cpu_offload, fp16, state_dict_rank0_and_offload
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
            partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
            partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()

            ctx = self._get_state_dict_mgr(
                model, state_dict_type, state_dict_rank0_and_offload
            )
            with ctx:
                fsdp_state_dict = _get_state_dict(
                    model, cpu_offload.offload_params, fp16
                )

            self._validate_state_dict_contents(
                fsdp_state_dict, state_dict_rank0_and_offload
            )
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)

            with FullyShardedDataParallel.summon_full_params(model):
                with FullyShardedDataParallel.summon_full_params(model_new):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertNotEqual(params, params_new)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

            with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
                    if fp16:
                        for tensor in model_new.parameters():
                            self.assertEqual(tensor.dtype, torch.float16)
Example #15
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full names of the linear_skip parameter tensors in SkipModel,
                # as they would be stored in a checkpoint.
                linear_skip_tensor_names = [
                    k for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # Detach linear_skip so the FSDP wrapper below does not manage it
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap the remaining module with FSDP
                fsdp = wrap(module)
                # Reattach the skipped submodule after wrapping
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp,
                                  STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter does not support deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

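        # Local state dicts hold flattened shards rather than the original
        # parameter names, so skip loading them into the unwrapped model below.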
        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)
Example #16
    def test_basic_save_and_load_state_dict(self, state_dict_type, cpu_offload,
                                            fp16,
                                            state_dict_rank0_and_offload):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
                partial(self._get_non_fsdp_root_module,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_nested_model,
                        cpu_offload=cpu_offload),
                partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()

            ctx = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
            with ctx:
                fsdp_state_dict = _get_state_dict(model,
                                                  cpu_offload.offload_params,
                                                  fp16)

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify that the saved tensors are fp16
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                if not isinstance(model, FSDP):
                    # Move everything to CPU to avoid
                    # https://github.com/pytorch/pytorch/issues/77113; some params
                    # will still be on GPU for non-FSDP root modules.
                    for k in fsdp_state_dict.keys():
                        fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            with FSDP.state_dict_type(model_new,
                                      STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model,
                                 model_new,
                                 self.assertEqual,
                                 check_fp16=fp16)
Example #17
    def test_save_and_load_after_forward_state_dict(
            self, state_dict_type, mixed_precision,
            state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        ) if mixed_precision else None)
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(model, state_dict_type,
                                           state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
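                # For local/sharded state dicts, clone each local shard so that
                # zeroing the model below does not also zero the saved values.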
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(model, state_dict,
                                           state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)