def test_get_early_stop_history_list_from_files(self):
    """ Should load fake EarlyStopHistoryList from pth files.

    A history list holding one checkpoint for each of two small Lenets is written via
    'result_saver' and read back via 'result_loader'.
    The loaded list needs to be equal to the saved one and its state_dicts need to be loadable into the nets. """
    plan_fc = [2]
    net0 = Net(NetNames.LENET, DatasetNames.MNIST, plan_conv=[], plan_fc=plan_fc)
    net1 = Net(NetNames.LENET, DatasetNames.MNIST, plan_conv=[], plan_fc=plan_fc)

    # presumably setup(net_count, prune_count), matching the specs below -- verify against EarlyStopHistoryList
    history_list = EarlyStopHistoryList()
    history_list.setup(2, 0)
    history_list.histories[0].state_dicts[0] = deepcopy(net0.state_dict())
    history_list.histories[1].state_dicts[0] = deepcopy(net1.state_dict())
    history_list.histories[0].indices[0] = 3
    history_list.histories[1].indices[0] = 42

    specs = get_specs_lenet_toy()
    specs.save_early_stop = True
    specs.net_count = 2
    specs.prune_count = 0

    with TemporaryDirectory() as tmp_dir_name:
        # save checkpoints
        result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix', history_list)

        # load and validate histories from file
        experiment_path_prefix = f"{tmp_dir_name}/prefix"
        loaded_history_list = result_loader.get_early_stop_history_list_from_files(experiment_path_prefix, specs)
        self.assertEqual(loaded_history_list, history_list)

        # use the state_dicts read back from disk, as loading the in-memory originals would not validate anything
        net0.load_state_dict(loaded_history_list.histories[0].state_dicts[0])
        net1.load_state_dict(loaded_history_list.histories[1].state_dicts[0])
def test_save_nets(self):
    """ Should write one pth file per net, each of which can be loaded into a fresh Lenet of the same plan. """
    plan_fc = [5]

    def build_net():
        """ Create a small Lenet for MNIST without convolutional layers. """
        return Net(NetNames.LENET, DatasetNames.MNIST, plan_conv=[], plan_fc=plan_fc)

    net_list = [build_net(), build_net()]

    with TemporaryDirectory() as tmp_dir_name:
        result_saver.save_nets(tmp_dir_name, 'prefix', net_list)

        # every checkpoint needs to fit into a freshly generated net
        for file_name in ('prefix-net0.pth', 'prefix-net1.pth'):
            checkpoint_path = os.path.join(tmp_dir_name, file_name)
            with open(checkpoint_path, 'rb') as checkpoint_file:
                checkpoint = t_load(checkpoint_file)
            restored_net = build_net()
            restored_net.load_state_dict(checkpoint)
def generate_model_from_state_dict(state_dict, specs):
    """ Build the net described by 'specs', restore 'state_dict' into it and return the net. """
    model = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
    model.load_state_dict(state_dict)

    # apply pruned masks, but do not modify the masks
    model.prune_net(0.0, 0.0, reset=False)
    return model
def generate_randomly_reinitialized_net(specs, state_dict):
    """ Build a net from 'state_dict' and randomly reinitialize its weights.
    The net has the same masks as the net specified by 'state_dict'. """
    specs_are_valid = isinstance(specs, ExperimentSpecs)
    assert specs_are_valid, f"'specs' needs to be ExperimentSpecs, but is {type(specs)}."

    reinitialized_net = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
    reinitialized_net.load_state_dict(state_dict)

    # draw new weights and store them as initial weights before calling prune_net with zero rates
    reinitialized_net.apply(gaussian_glorot)
    reinitialized_net.store_initial_weights()
    reinitialized_net.prune_net(0.0, 0.0)
    return reinitialized_net
def test_save_early_stop_history_list(self):
    """ Should write one pth file per fake EarlyStopHistory, each holding the saved indices and a loadable
    state_dict. """
    plan_fc = [2]

    def build_net():
        """ Create a small Lenet for MNIST without convolutional layers. """
        return Net(NetNames.LENET, DatasetNames.MNIST, plan_conv=[], plan_fc=plan_fc)

    nets = [build_net(), build_net()]

    # one fake checkpoint per net
    history_list = EarlyStopHistoryList()
    history_list.setup(2, 0)
    for net_num, (net, index) in enumerate(zip(nets, (3, 42))):
        history_list.histories[net_num].state_dicts[0] = deepcopy(net.state_dict())
        history_list.histories[net_num].indices[0] = index

    with TemporaryDirectory() as tmp_dir_name:
        # save checkpoints
        result_saver.save_early_stop_history_list(tmp_dir_name, 'prefix', history_list)

        # load and validate histories from file
        file_names = ('prefix-early-stop0.pth', 'prefix-early-stop1.pth')
        for net_num, file_name in enumerate(file_names):
            with open(os.path.join(tmp_dir_name, file_name), 'rb') as history_file:
                reconstructed_hist = t_load(history_file)

            expected_indices = history_list.histories[net_num].indices
            np.testing.assert_array_equal(reconstructed_hist.indices, expected_indices)
            build_net().load_state_dict(reconstructed_hist.state_dicts[0])