def test_torch_model_on_begin(self):
     """Check ``RestoreWizard.on_begin`` against a previously saved torch system.

     The models and the system state are first written into the wizard's
     directory, then ``on_begin`` is invoked and the directory attribute and
     ``self.data['epoch']`` are verified.
     """
     try:
         restore_wizard = RestoreWizard(directory=self.system_json_path)
         restore_wizard.system = sample_system_object_torch()
         # Save state so that on_begin has something to pick up.
         for model in restore_wizard.system.network.models:
             save_model(model,
                        save_dir=restore_wizard.directory,
                        save_optimizer=True)
         state_file = os.path.join(restore_wizard.directory,
                                   restore_wizard.system_file)
         restore_wizard.system.save_state(json_path=state_file)
         restore_wizard.on_begin(data=self.data)
         with self.subTest('Check the restore files directory'):
             self.assertEqual(restore_wizard.directory, self.system_json_path)
         with self.subTest('check data dictionary'):
             self.assertEqual(self.data['epoch'], 0)
     finally:
         # Always remove the directory, even when a step above raises, so a
         # stale save cannot leak into other tests using the same path.
         if os.path.exists(self.system_json_path):
             shutil.rmtree(self.system_json_path)
 def test_torch_model_on_epoch_end(self):
     """Check the files present after ``RestoreWizard.on_epoch_end`` (torch).

     After the call, the directory is expected to contain ``system.json`` as
     well as ``<name>.pt`` and ``<name>_opt.pt`` for the first model name
     returned by ``get_model_name``.
     """
     try:
         restore_wizard = RestoreWizard(directory=self.system_json_path)
         restore_wizard.system = sample_system_object_torch()
         restore_wizard.on_epoch_end(data=self.data)
         first_model_name = get_model_name(restore_wizard.system)[0]
         with self.subTest('check json exists'):
             json_file = os.path.join(self.system_json_path, 'system.json')
             self.assertTrue(os.path.exists(json_file))
         with self.subTest('Check if model weights path stored'):
             weights_file = os.path.join(self.system_json_path,
                                         first_model_name + '.pt')
             self.assertTrue(os.path.exists(weights_file))
         with self.subTest('Check if model optimizer stored'):
             optimizer_file = os.path.join(self.system_json_path,
                                           first_model_name + '_opt.pt')
             self.assertTrue(os.path.exists(optimizer_file))
     finally:
         # Always remove the directory, even when a step above raises, so a
         # stale save cannot leak into other tests using the same path.
         if os.path.exists(self.system_json_path):
             shutil.rmtree(self.system_json_path)
Esempio n. 3
0
def get_estimator(epochs=2,
                  batch_size=32,
                  save_dir=None,
                  restore_dir=None):
    """Build an ``fe.Estimator`` that trains LeNet on the MNIST data.

    Args:
        epochs: Number of epochs handed to ``fe.Estimator``.
        batch_size: Batch size handed to ``fe.Pipeline``.
        save_dir: Directory given to ``BestModelSaver``. When None, a new
            temporary directory is created for this call.
        restore_dir: Directory given to ``RestoreWizard``. When None, a new
            temporary directory is created for this call.

    Returns:
        The configured ``fe.Estimator`` instance.
    """
    # Resolve the directories per call. Using ``tempfile.mkdtemp()`` directly
    # as a default value would run it once at definition time, creating
    # directories on import and sharing them between every default call.
    if save_dir is None:
        save_dir = tempfile.mkdtemp()
    if restore_dir is None:
        restore_dir = tempfile.mkdtemp()

    # step 1: data pipeline
    train_data, eval_data = mnist.load_data()
    test_data = eval_data.split(0.5)
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=[
                               ExpandDims(inputs="x", outputs="x"),
                               Minmax(inputs="x", outputs="x")
                           ])

    # step 2: model and network
    model = fe.build(model_fn=LeNet, optimizer_fn="adam")
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")
    ])

    # step 3: traces and estimator
    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model,
                       save_dir=save_dir,
                       metric="accuracy",
                       save_best_mode="max"),
        LRScheduler(model=model,
                    lr_fn=lambda step: cosine_decay(
                        step, cycle_length=3750, init_lr=1e-3)),
        RestoreWizard(directory=restore_dir)
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces)
    return estimator
    def test_restore(self):
        """Check that a second RestoreWizard picks up what a first one saved.

        A first wizard runs ``on_begin``, gets ``global_step`` and
        ``epoch_idx`` set on its system, and runs ``on_epoch_end``. A second
        wizard on the same directory then runs ``on_begin``; afterwards the
        data's 'epoch' entry and the system counters are verified.
        """
        global_step = 100
        epoch_idx = 10

        # TemporaryDirectory removes the directory when the block exits; a
        # bare mkdtemp() would leave it behind after every test run.
        with tempfile.TemporaryDirectory() as save_path:
            # First wizard: write a save with known counters.
            saver = RestoreWizard(directory=save_path)
            saver.system = sample_system_object()
            saver.on_begin(Data())
            saver.system.global_step = global_step
            saver.system.epoch_idx = epoch_idx
            saver.on_epoch_end(Data())

            # Second wizard: start from the same directory.
            restore_wizard = RestoreWizard(directory=save_path)
            restore_wizard.system = sample_system_object()
            data = Data()
            restore_wizard.on_begin(data)
            with self.subTest("Check print message"):
                self.assertEqual(data['epoch'], 10)
            with self.subTest("Check system variables"):
                self.assertEqual(restore_wizard.system.global_step,
                                 global_step)
                self.assertEqual(restore_wizard.system.epoch_idx, epoch_idx)
 def test_save(self):
     """Check the files present after each of four ``on_epoch_end`` calls.

     After every call the directory must contain ``key.txt`` and an entry
     named after the expected key, and the first line of ``key.txt`` must be
     that key. The expected keys over the four calls are 'A', 'B', 'A', 'B'.
     """
     # TemporaryDirectory removes the directory when the block exits; a bare
     # mkdtemp() would leave it behind after every test run.
     with tempfile.TemporaryDirectory() as save_path:
         restore_wizard = RestoreWizard(directory=save_path)
         restore_wizard.system = sample_system_object()
         restore_wizard.on_begin(Data())
         key_file = os.path.join(save_path, 'key.txt')
         for round_no, expected_key in enumerate("ABAB", start=1):
             restore_wizard.on_epoch_end(Data())
             with self.subTest("Check Saved Files ({})".format(round_no)):
                 self.assertTrue(os.path.exists(key_file))
                 self.assertTrue(
                     os.path.exists(os.path.join(save_path, expected_key)))
             with self.subTest("Check Key is Correct ({})".format(round_no)):
                 with open(key_file, 'r') as file:
                     key = file.readline()
                 self.assertEqual(key, expected_key)