def _run_train(cuda_devices: List[str] = None):
    """
    Train the demo model, reload the checkpoint and check the restored state.

    The serialization directory is recreated empty, a ``Trainer`` is run with
    the demo dataset serving as both training and validation data, and the
    trainer state captured right after training is compared with the state
    observed after ``load_checkpoint``.

    :param cuda_devices: passed through unchanged to ``Trainer(cuda_devices=...)``
    """
    checkpoint_dir = os.path.join(ROOT_PATH,
                                  "data/easytext/tests/trainer/save_and_load")

    # Start every run from an empty serialization directory.
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    os.makedirs(checkpoint_dir)

    demo_model = ModelDemo()
    demo_optimizer_factory = _DemoOptimizerFactory()
    demo_loss = _DemoLoss()
    demo_metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=demo_model,
                      loss=demo_loss,
                      metrics=demo_metric,
                      optimizer_factory=demo_optimizer_factory,
                      serialize_dir=checkpoint_dir,
                      patient=20,
                      num_check_point_keep=25,
                      cuda_devices=cuda_devices)

    demo_dataset = _DemoDataset()

    def _build_loader():
        # Both loaders share the dataset instance but get their own collate object.
        return DataLoader(dataset=demo_dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          shuffle=False,
                          num_workers=0)

    train_loader = _build_loader()
    validation_loader = _build_loader()

    trainer.train(train_data_loader=train_loader,
                  validation_data_loader=validation_loader)

    def _to_plain(value):
        # Round-trip through JSON so the values can be compared as plain dicts.
        return json.loads(json2str(value))

    def _snapshot():
        # Read order mirrors the comparison order used below.
        model_state = _to_plain(trainer.model.state_dict())
        optimizer_state = _to_plain(trainer.optimizer.state_dict())
        epoch = trainer.current_epoch
        total_epoch = trainer.num_epoch
        # NOTE(review): metric[0] is taken as-is (no JSON round trip); if
        # load_checkpoint updates this same object in place the comparison
        # below cannot fail - confirm against Trainer.
        first_metric = trainer.metrics.metric[0]
        tracker = _to_plain(trainer.metric_tracker)
        return model_state, optimizer_state, epoch, total_epoch, first_metric, tracker

    (expect_model_state,
     expect_optimizer_state,
     expect_epoch,
     expect_total_epoch,
     expect_first_metric,
     expect_tracker) = _snapshot()

    trainer.load_checkpoint(serialize_dir=checkpoint_dir)

    (loaded_model_state,
     loaded_optimizer_state,
     loaded_epoch,
     loaded_total_epoch,
     loaded_first_metric,
     loaded_tracker) = _snapshot()

    ASSERT.assertDictEqual(expect_model_state, loaded_model_state)
    ASSERT.assertDictEqual(expect_optimizer_state, loaded_optimizer_state)
    ASSERT.assertEqual(expect_epoch, loaded_epoch)
    ASSERT.assertEqual(expect_total_epoch, loaded_total_epoch)
    ASSERT.assertDictEqual(expect_first_metric, loaded_first_metric)
    ASSERT.assertDictEqual(expect_tracker, loaded_tracker)
def _run_train(device: torch.device, is_distributed: bool):
    """
    Train the demo model, reload the checkpoint and check the restored state.

    The serialization directory is recreated empty (by rank 0 only when
    running distributed, followed by a barrier), a ``Trainer`` is run with the
    demo dataset serving as both training and validation data, and the trainer
    state captured right after training is compared with the state observed
    after ``load_checkpoint``.

    :param device: passed through unchanged to ``Trainer(device=...)``
    :param is_distributed: when True, the directory reset is restricted to
        rank 0, both data loaders use a ``DistributedSampler``, and the flag
        is forwarded to ``Trainer(is_distributed=...)``
    """
    checkpoint_dir = os.path.join(ROOT_PATH,
                                  "data/easytext/tests/trainer/save_and_load")

    # Single-process runs reset the directory themselves; distributed runs
    # leave that to rank 0 (get_rank is only queried in distributed mode).
    resets_directory = (not is_distributed) or TorchDist.get_rank() == 0

    if resets_directory:
        if os.path.isdir(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
        os.makedirs(checkpoint_dir)

    if is_distributed:
        # Every rank waits here, after rank 0 has recreated the directory.
        TorchDist.barrier()

    demo_model = ModelDemo()
    demo_optimizer_factory = _DemoOptimizerFactory()
    demo_loss = _DemoLoss()
    demo_metric = _DemoMetric()

    # Not used below: the trainer is created with trainer_callback=None, so no
    # tensorboard callback consumes this path.
    tensorboard_log_dir = os.path.join(ROOT_PATH, "data/tensorboard")

    trainer = Trainer(num_epoch=100,
                      model=demo_model,
                      loss=demo_loss,
                      metrics=demo_metric,
                      optimizer_factory=demo_optimizer_factory,
                      serialize_dir=checkpoint_dir,
                      patient=20,
                      num_check_point_keep=25,
                      device=device,
                      trainer_callback=None,
                      is_distributed=is_distributed)

    logging.info(f"test is_distributed: {is_distributed}")

    demo_dataset = _DemoDataset()

    def _build_loader():
        # Each loader gets its own sampler (distributed only) and collate object.
        loader_sampler = (DistributedSampler(dataset=demo_dataset)
                          if is_distributed else None)
        return DataLoader(dataset=demo_dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          num_workers=0,
                          sampler=loader_sampler)

    train_loader = _build_loader()
    validation_loader = _build_loader()

    trainer.train(train_data_loader=train_loader,
                  validation_data_loader=validation_loader)

    def _to_plain(value):
        # Round-trip through JSON so the values can be compared as plain dicts.
        return json.loads(json2str(value))

    def _snapshot():
        # Read order mirrors the comparison order used below.
        model_state = _to_plain(trainer.model.state_dict())
        optimizer_state = _to_plain(trainer.optimizer.state_dict())
        epoch = trainer.current_epoch
        total_epoch = trainer.num_epoch
        # NOTE(review): metric[0] is taken as-is (no JSON round trip); if
        # load_checkpoint updates this same object in place the comparison
        # below cannot fail - confirm against Trainer.
        first_metric = trainer.metrics.metric[0]
        tracker = _to_plain(trainer.metric_tracker)
        return model_state, optimizer_state, epoch, total_epoch, first_metric, tracker

    (expect_model_state,
     expect_optimizer_state,
     expect_epoch,
     expect_total_epoch,
     expect_first_metric,
     expect_tracker) = _snapshot()

    trainer.load_checkpoint(serialize_dir=checkpoint_dir)

    (loaded_model_state,
     loaded_optimizer_state,
     loaded_epoch,
     loaded_total_epoch,
     loaded_first_metric,
     loaded_tracker) = _snapshot()

    ASSERT.assertDictEqual(expect_model_state, loaded_model_state)
    ASSERT.assertDictEqual(expect_optimizer_state, loaded_optimizer_state)
    ASSERT.assertEqual(expect_epoch, loaded_epoch)
    ASSERT.assertEqual(expect_total_epoch, loaded_total_epoch)
    ASSERT.assertDictEqual(expect_first_metric, loaded_first_metric)
    ASSERT.assertDictEqual(expect_tracker, loaded_tracker)