コード例 #1
0
    def test_get_early_stop_history_list_from_files(self):
        """ Should load fake EarlyStopHistoryList from pth files.

        Two fake histories are saved into a temporary directory, read back via the result_loader and compared
        with the originals. Afterwards the checkpoints read from disk are loaded into the nets. """
        plan_fc = [2]
        net0 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)
        net1 = Net(NetNames.LENET,
                   DatasetNames.MNIST,
                   plan_conv=[],
                   plan_fc=plan_fc)
        history_list = EarlyStopHistoryList()
        history_list.setup(2, 0)
        history_list.histories[0].state_dicts[0] = deepcopy(net0.state_dict())
        history_list.histories[1].state_dicts[0] = deepcopy(net1.state_dict())
        history_list.histories[0].indices[0] = 3
        history_list.histories[1].indices[0] = 42

        # specs need to match the saved histories, i.e. two nets and no pruning
        specs = get_specs_lenet_toy()
        specs.save_early_stop = True
        specs.net_count = 2
        specs.prune_count = 0

        with TemporaryDirectory() as tmp_dir_name:
            # save checkpoints
            result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix',
                                                      history_list)

            # load and validate histories from file
            experiment_path_prefix = f"{tmp_dir_name}/prefix"
            loaded_history_list = result_loader.get_early_stop_history_list_from_files(
                experiment_path_prefix, specs)
            self.assertEqual(loaded_history_list, history_list)

            # use the checkpoints read from disk (not the in-memory originals), otherwise the loaded
            # state_dicts would never be checked for compatibility with the nets
            net0.load_state_dict(
                loaded_history_list.histories[0].state_dicts[0])
            net1.load_state_dict(
                loaded_history_list.histories[1].state_dicts[0])
コード例 #2
0
    def test_perform_toy_lenet_experiment(self):
        """ Should run ExperimentRandomRetrain with small Lenet and toy-dataset and write one histories-file. """
        specs = get_specs_lenet_toy()
        specs.prune_count = 1
        specs.save_early_stop = True

        # fake history whose checkpoints 0 and 1 both hold the state of a newly created net
        history = EarlyStopHistory()
        history.setup(specs.prune_count)
        net = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
        for checkpoint_index in (0, 1):
            history.state_dicts[checkpoint_index] = net.state_dict()

        history_list = EarlyStopHistoryList()
        history_list.setup(1, 0)
        history_list.histories[0] = history

        fake_loaders = generate_fake_mnist_data_loaders()
        patch_target = 'experiments.experiment.get_mnist_data_loaders'
        with mock.patch(patch_target, return_value=fake_loaders), \
                TemporaryDirectory() as tmp_dir_name:
            # all results are saved into a temporary folder
            result_saver.save_specs(tmp_dir_name, 'prefix', specs)
            result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix',
                                                      history_list)

            path_to_specs = os.path.join(tmp_dir_name, 'prefix-specs.json')
            experiment = ExperimentRandomRetrain(path_to_specs, 0, 1)
            experiment.run_experiment()

            # exactly one histories-file is expected after the run
            expected_file = os.path.join(tmp_dir_name,
                                         'prefix-random-histories0.npz')
            matching_files = glob.glob(expected_file)
            self.assertEqual(1, len(matching_files))
コード例 #3
0
    def test_setup_early_stop_history_list(self):
        """ Should create an np.ndarray which holds exactly one EarlyStopHistory. """
        history_list = EarlyStopHistoryList()
        history_list.setup(1, 0)
        created = history_list.histories

        self.assertIsInstance(created[0], EarlyStopHistory)
        self.assertIsInstance(created, np.ndarray)
        self.assertEqual(1, len(created))
コード例 #4
0
def get_early_stop_history_list_from_files(experiment_path_prefix, specs):
    """ Load one EarlyStopHistory per net from its pth-file and return all of them in one EarlyStopHistoryList.

    'specs' needs to be an ExperimentSpecs object with 'save_early_stop' enabled, otherwise an assertion fails.
    'specs.net_count' defines how many files are read and every file is loaded onto the CPU. """
    assert isinstance(specs, ExperimentSpecs), \
        f"'specs' has invalid type {type(specs)}."
    assert specs.save_early_stop, \
        "'save_early_stop' is False in given 'specs', i.e. no EarlyStopHistoryList exists."

    cpu_device = torch.device("cpu")
    loaded_histories = EarlyStopHistoryList()
    loaded_histories.setup(specs.net_count, specs.prune_count)

    # NOTE(review): torch.load unpickles whole objects, so only files from trusted sources should be read here
    for net_number in range(specs.net_count):
        file_path = generate_early_stop_file_path(experiment_path_prefix,
                                                  net_number)
        loaded_histories.histories[net_number] = torch.load(
            file_path, map_location=cpu_device)

    return loaded_histories
コード例 #5
0
    def test_save_early_stop_history_list(self):
        """ Should write one pth file per fake EarlyStopHistory, i.e. two files which can be loaded again. """
        plan_fc = [2]

        def build_net():
            """ Create a Lenet without conv-layers for MNIST. """
            return Net(NetNames.LENET,
                       DatasetNames.MNIST,
                       plan_conv=[],
                       plan_fc=plan_fc)

        nets = [build_net(), build_net()]
        history_list = EarlyStopHistoryList()
        history_list.setup(2, 0)
        for history, net, index in zip(history_list.histories, nets, (3, 42)):
            history.state_dicts[0] = deepcopy(net.state_dict())
            history.indices[0] = index

        file_names = ('prefix-early-stop0.pth', 'prefix-early-stop1.pth')
        with TemporaryDirectory() as tmp_dir_name:
            # save checkpoints
            result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix',
                                                      history_list)

            # load and validate histories from file
            for net_num, file_name in enumerate(file_names):
                result_file_path = os.path.join(tmp_dir_name, file_name)
                with open(result_file_path, 'rb') as result_file:
                    reconstructed_hist = t_load(result_file)

                net = build_net()
                expected_hist = history_list.histories[net_num]
                np.testing.assert_array_equal(reconstructed_hist.indices,
                                              expected_hist.indices)
                # the saved checkpoint needs to fit into a net with the same architecture
                net.load_state_dict(reconstructed_hist.state_dicts[0])
コード例 #6
0
    def __init__(self, specs):
        """ Store 'specs', log them and create the device and empty history-containers.

        Trainer, epoch-length and the contents of the histories are placeholders here, which are set up later. """
        super().__init__()
        self.specs = specs
        log_from_medium(specs.verbosity, specs)

        self.device = torch.device(specs.device)

        # placeholders, the actual values are set in load_data_and_setup_trainer()
        self.epoch_length = 0
        self.trainer = None

        # empty containers, the history-arrays are set up in setup_experiment()
        self.stop_hists = EarlyStopHistoryList()
        self.hists = ExperimentHistories()
コード例 #7
0
    def test_early_stop_history_lists_are_unequal(self):
        """ Should return False, because an index was changed in one of the two EarlyStopHistoryLists. """
        untouched_list = EarlyStopHistoryList()
        untouched_list.setup(2, 1)

        modified_list = EarlyStopHistoryList()
        modified_list.setup(2, 1)
        modified_list.histories[0].indices[0] = 5

        are_equal = EarlyStopHistoryList.__eq__(untouched_list, modified_list)
        self.assertIs(are_equal, False)
コード例 #8
0
 def test_early_stop_history_lists_are_equal(self):
     """ Should return True, because an EarlyStopHistoryList is compared with itself. """
     same_list = EarlyStopHistoryList()
     same_list.setup(2, 1)

     are_equal = EarlyStopHistoryList.__eq__(same_list, same_list)
     self.assertIs(are_equal, True)