コード例 #1
0
    def test_get_early_stop_history_list_from_files(self):
        """ Should load fake EarlyStopHistoryList from pth files.

        A history list holding one checkpoint for each of two small Lenets is
        saved into a temporary directory and read back through the loader.
        The loaded list must equal the saved one, and the checkpoints which
        were read from disk must be loadable into nets of the same layout. """
        plan_fc = [2]
        net0 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)
        net1 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)

        # fake histories: one entry per net with distinguishable indices
        history_list = EarlyStopHistoryList()
        history_list.setup(2, 0)
        history_list.histories[0].state_dicts[0] = deepcopy(net0.state_dict())
        history_list.histories[1].state_dicts[0] = deepcopy(net1.state_dict())
        history_list.histories[0].indices[0] = 3
        history_list.histories[1].indices[0] = 42

        # specs need to describe the same setup as the fake history list
        specs = get_specs_lenet_toy()
        specs.save_early_stop = True
        specs.net_count = 2
        specs.prune_count = 0

        with TemporaryDirectory() as tmp_dir_name:
            # save checkpoints
            result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix',
                                                      history_list)

            # load and validate histories from file
            experiment_path_prefix = f"{tmp_dir_name}/prefix"
            loaded_history_list = result_loader.get_early_stop_history_list_from_files(
                experiment_path_prefix, specs)
            self.assertEqual(loaded_history_list, history_list)

            # use the checkpoints read from disk (not the in-memory originals),
            # otherwise the usability of the loaded state_dicts stays untested
            net0.load_state_dict(
                loaded_history_list.histories[0].state_dicts[0])
            net1.load_state_dict(
                loaded_history_list.histories[1].state_dicts[0])
コード例 #2
0
    def test_save_nets(self):
        """ Should save two small Lenet instances into pth files.

        Each expected file is opened, its checkpoint is loaded and then fed
        into a freshly built Lenet with the same fully-connected plan. """
        plan_fc = [5]
        net_list = [
            Net(NetNames.LENET,
                DatasetNames.MNIST,
                plan_conv=[],
                plan_fc=plan_fc) for _ in range(2)
        ]

        with TemporaryDirectory() as tmp_dir_name:
            result_saver.save_nets(tmp_dir_name, 'prefix', net_list)

            # file names which are expected for the two saved nets
            for file_name in ('prefix-net0.pth', 'prefix-net1.pth'):
                checkpoint_path = os.path.join(tmp_dir_name, file_name)
                with open(checkpoint_path, 'rb') as checkpoint_file:
                    checkpoint = t_load(checkpoint_file)

                # reconstruct a net of the same layout from the checkpoint
                restored_net = Net(NetNames.LENET,
                                   DatasetNames.MNIST,
                                   plan_conv=[],
                                   plan_fc=plan_fc)
                restored_net.load_state_dict(checkpoint)
コード例 #3
0
def generate_model_from_state_dict(state_dict, specs):
    """ Build the net described by 'specs' and restore it from 'state_dict'.

    The architecture is taken from 'specs' (net, dataset, plan_conv, plan_fc).
    After loading the checkpoint, 'prune_net' is called with zero rates and
    'reset=False'. The restored net is returned. """
    model = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
    model.load_state_dict(state_dict)

    # zero rates without reset: apply pruned masks, but do not modify the masks
    model.prune_net(0.0, 0.0, reset=False)
    return model
コード例 #4
0
 def generate_randomly_reinitialized_net(specs, state_dict):
     """ Create a net from 'state_dict' whose weights are randomly reinitialized.

     The architecture is taken from 'specs', which needs to be ExperimentSpecs.
     The checkpoint is loaded first, so the net has the same masks like the net
     specified by 'state_dict'. Then 'gaussian_glorot' is applied, the new
     weights are stored as initial weights and 'prune_net' is called with zero
     rates (presumably to apply the existing masks to the new weights). """
     assert isinstance(
         specs, ExperimentSpecs
     ), f"'specs' needs to be ExperimentSpecs, but is {type(specs)}."

     reinitialized_net = Net(specs.net, specs.dataset, specs.plan_conv,
                             specs.plan_fc)
     reinitialized_net.load_state_dict(state_dict)

     # replace the loaded weights and remember them as new initial weights
     reinitialized_net.apply(gaussian_glorot)
     reinitialized_net.store_initial_weights()

     reinitialized_net.prune_net(0.0, 0.0)
     return reinitialized_net
コード例 #5
0
    def test_save_early_stop_history_list(self):
        """ Should save two fake EarlyStopHistories into two pth files.

        Every expected file is loaded again. Its indices are compared with the
        saved history of the same position and its first checkpoint is loaded
        into a freshly built Lenet. """
        plan_fc = [2]
        net0 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)
        net1 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)

        # fake histories: one entry per net with distinguishable indices
        history_list = EarlyStopHistoryList()
        history_list.setup(2, 0)
        history_list.histories[0].state_dicts[0] = deepcopy(net0.state_dict())
        history_list.histories[1].state_dicts[0] = deepcopy(net1.state_dict())
        history_list.histories[0].indices[0] = 3
        history_list.histories[1].indices[0] = 42

        expected_file_names = ('prefix-early-stop0.pth',
                               'prefix-early-stop1.pth')

        with TemporaryDirectory() as tmp_dir_name:
            # save checkpoints
            result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix',
                                                      history_list)

            # load and validate histories from file
            for net_num, file_name in enumerate(expected_file_names):
                history_path = os.path.join(tmp_dir_name, file_name)
                with open(history_path, 'rb') as history_file:
                    reconstructed_hist = t_load(history_file)

                expected_hist = history_list.histories[net_num]
                fresh_net = Net(NetNames.LENET,
                                DatasetNames.MNIST,
                                plan_conv=[],
                                plan_fc=plan_fc)
                np.testing.assert_array_equal(reconstructed_hist.indices,
                                              expected_hist.indices)
                fresh_net.load_state_dict(reconstructed_hist.state_dicts[0])