def test_torch_model_on_begin(self):
    """Restore a previously saved torch system through ``RestoreWizard.on_begin``.

    Every model in the sample torch system is saved (with its optimizer), and
    the system state is written to ``restore_wizard.system_file``, all inside
    ``self.system_json_path``. ``on_begin`` is then called with ``self.data``.
    The test checks the wizard's directory and the ``'epoch'`` entry written
    into ``self.data``.

    The saved directory is removed in a ``finally`` block. Cleanup therefore
    also happens when saving or restoring raises, so no leftovers remain at
    ``self.system_json_path`` for other tests that use the same path.
    """
    try:
        restore_wizard = RestoreWizard(directory=self.system_json_path)
        restore_wizard.system = sample_system_object_torch()

        # Save state so that on_begin has something to restore from.
        for model in restore_wizard.system.network.models:
            save_model(model, save_dir=restore_wizard.directory, save_optimizer=True)
        restore_wizard.system.save_state(
            json_path=os.path.join(restore_wizard.directory, restore_wizard.system_file))

        restore_wizard.on_begin(data=self.data)

        with self.subTest('Check the restore files directory'):
            self.assertEqual(restore_wizard.directory, self.system_json_path)
        with self.subTest('check data dictionary'):
            self.assertEqual(self.data['epoch'], 0)
    finally:
        if os.path.exists(self.system_json_path):
            shutil.rmtree(self.system_json_path)
def test_torch_model_on_epoch_end(self):
    """Check that ``RestoreWizard.on_epoch_end`` writes torch checkpoint files.

    After ``on_epoch_end`` the following files are expected inside
    ``self.system_json_path``:

    * ``system.json``
    * ``<first model name>.pt``
    * ``<first model name>_opt.pt``

    The directory is removed in a ``finally`` block, so it is cleaned up even
    when the wizard raises before the assertions run.
    """
    try:
        restore_wizard = RestoreWizard(directory=self.system_json_path)
        restore_wizard.system = sample_system_object_torch()
        restore_wizard.on_epoch_end(data=self.data)
        model_names = get_model_name(restore_wizard.system)

        def _saved(file_name):
            # True when file_name exists directly inside the checkpoint directory.
            return os.path.exists(os.path.join(self.system_json_path, file_name))

        with self.subTest('check json exists'):
            self.assertTrue(_saved('system.json'))
        with self.subTest('Check if model weights path stored'):
            self.assertTrue(_saved(model_names[0] + '.pt'))
        with self.subTest('Check if model optimizer stored'):
            self.assertTrue(_saved(model_names[0] + '_opt.pt'))
    finally:
        if os.path.exists(self.system_json_path):
            shutil.rmtree(self.system_json_path)
def get_estimator(epochs=2, batch_size=32, save_dir=None, restore_dir=None):
    """Build a LeNet/MNIST ``fe.Estimator`` with best-model saving and restore support.

    Args:
        epochs: Number of training epochs passed to the estimator.
        batch_size: Batch size used by the pipeline.
        save_dir: Directory for ``BestModelSaver``. When None, a fresh
            temporary directory is created for this call.
        restore_dir: Directory for ``RestoreWizard``. When None, a fresh
            temporary directory is created for this call.

    Returns:
        The configured ``fe.Estimator``.
    """
    # Create the default directories per call. A `tempfile.mkdtemp()` default
    # would be evaluated once at import time and shared by every call. A later
    # estimator's RestoreWizard could then pick up state saved by an earlier one.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()
    if restore_dir is None:
        restore_dir = tempfile.mkdtemp()

    # step 1: data pipeline (part of the eval split is held out as test data)
    train_data, eval_data = mnist.load_data()
    test_data = eval_data.split(0.5)
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=[
                               ExpandDims(inputs="x", outputs="x"),
                               Minmax(inputs="x", outputs="x"),
                           ])

    # step 2: model and network
    model = fe.build(model_fn=LeNet, optimizer_fn="adam")
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce"),
    ])

    # step 3: traces
    # NOTE(review): cycle_length=3750 presumably equals 2 epochs x 1875 steps
    # (60000 MNIST training samples / batch 32). It does not follow the
    # `epochs`/`batch_size` arguments; confirm before changing those defaults.
    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model, save_dir=save_dir, metric="accuracy", save_best_mode="max"),
        LRScheduler(model=model, lr_fn=lambda step: cosine_decay(step, cycle_length=3750, init_lr=1e-3)),
        RestoreWizard(directory=restore_dir),
    ]
    estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=epochs, traces=traces)
    return estimator
def test_restore(self):
    """Round-trip ``global_step`` and ``epoch_idx`` through a RestoreWizard directory.

    A first wizard is started with ``on_begin``, has its system's counters set,
    and saves them with ``on_epoch_end``. A second wizard on the same
    directory must then restore those counters in ``on_begin`` and report the
    epoch in the passed ``Data``.

    A ``TemporaryDirectory`` is used so the checkpoint directory is removed
    when the test finishes.
    """
    global_step = 100
    epoch_idx = 10
    with tempfile.TemporaryDirectory() as save_path:
        # First wizard: save state at the end of an epoch.
        restore_wizard = RestoreWizard(directory=save_path)
        restore_wizard.system = sample_system_object()
        restore_wizard.on_begin(Data())
        restore_wizard.system.global_step = global_step
        restore_wizard.system.epoch_idx = epoch_idx
        restore_wizard.on_epoch_end(Data())

        # Second wizard with a fresh system: restore from the same directory.
        restore_wizard = RestoreWizard(directory=save_path)
        restore_wizard.system = sample_system_object()
        data = Data()
        restore_wizard.on_begin(data)

        with self.subTest("Check print message"):
            self.assertEqual(data['epoch'], epoch_idx)
        with self.subTest("Check system variables"):
            self.assertEqual(restore_wizard.system.global_step, global_step)
            self.assertEqual(restore_wizard.system.epoch_idx, epoch_idx)
def test_save(self):
    """Check that RestoreWizard alternates between checkpoint slots A and B.

    After each ``on_epoch_end``, ``key.txt`` should name the slot just written
    and that slot should exist in the directory. Over four epoch ends the slots
    are expected to be A, B, A, B.

    A ``TemporaryDirectory`` is used so the checkpoint directory is removed
    when the test finishes.
    """
    with tempfile.TemporaryDirectory() as save_path:
        key_path = os.path.join(save_path, 'key.txt')
        restore_wizard = RestoreWizard(directory=save_path)
        restore_wizard.system = sample_system_object()
        restore_wizard.on_begin(Data())

        for round_idx, expected_key in enumerate(("A", "B", "A", "B"), start=1):
            restore_wizard.on_epoch_end(Data())
            with self.subTest(f"Check Saved Files ({round_idx})"):
                self.assertTrue(os.path.exists(key_path))
                self.assertTrue(os.path.exists(os.path.join(save_path, expected_key)))
            with self.subTest(f"Check Key is Correct ({round_idx})"):
                with open(key_path, 'r') as file:
                    key = file.readline()
                self.assertEqual(key, expected_key)