def test_metric_tracker_patient():
    """Check early stopping with ``patient=1`` and the best record it keeps.

    Records from ``METRICS`` are added one epoch at a time. ``early_stopping``
    must report False for epochs up to 4 and True for any later epoch; the
    loop exits as soon as it reports True. The best record must be epoch 3.
    """
    tracker = MetricTracker(patient=1)

    for record in METRICS:
        tracker.add_metric(**record)
        epoch = record["epoch"]

        # Epochs 1..4 must not trigger early stopping; anything later must.
        check = ASSERT.assertTrue if epoch > 4 else ASSERT.assertFalse
        check(tracker.early_stopping(epoch))

        if tracker.early_stopping(epoch):
            break

    train_target = ModelTargetMetric(metric_name="acc", metric_value=0.85)
    validation_target = ModelTargetMetric(metric_name="acc", metric_value=0.60)

    best = tracker.best()
    ASSERT.assertEqual(3, best.epoch)
    ASSERT.assertDictEqual({"acc": 0.85}, best.train_metric)
    ASSERT.assertDictEqual({"acc": 0.60}, best.validation_metric)

    # Compare name then value, train target first, validation target second.
    target_pairs = ((train_target, best.train_model_target_metric),
                    (validation_target, best.validation_model_target_metric))
    for expected, actual in target_pairs:
        ASSERT.assertEqual(expected.name, actual.name)
        ASSERT.assertEqual(expected.value, actual.value)
def test_metric_tracker_best():
    """Test ``MetricTracker.best`` after feeding every record of ``METRICS``.

    The tracker is built with ``patient=None`` and all records are added
    (the loop never breaks). The best record must be epoch 3 with train
    acc 0.85 and validation acc 0.60.
    """
    tracker = MetricTracker(patient=None)
    for record in METRICS:
        tracker.add_metric(**record)

    expected_train_target = ModelTargetMetric(metric_name="acc", metric_value=0.85)
    expected_validation_target = ModelTargetMetric(metric_name="acc", metric_value=0.60)

    best = tracker.best()

    ASSERT.assertEqual(3, best.epoch)
    ASSERT.assertDictEqual({"acc": 0.85}, best.train_metric)
    ASSERT.assertDictEqual({"acc": 0.60}, best.validation_metric)

    actual_train_target = best.train_model_target_metric
    ASSERT.assertEqual(expected_train_target.name, actual_train_target.name)
    ASSERT.assertEqual(expected_train_target.value, actual_train_target.value)

    actual_validation_target = best.validation_model_target_metric
    ASSERT.assertEqual(expected_validation_target.name, actual_validation_target.name)
    ASSERT.assertEqual(expected_validation_target.value, actual_validation_target.value)
def test_metric_tracker_save_and_load():
    """Check that a saved ``MetricTracker`` reloads with an identical best record.

    The tracker is fed ``METRICS`` with ``patient=1``. The loop asserts that
    early stopping fires only after epoch 4 and exits when it does. The
    tracker is then written with ``save``, read back with
    ``MetricTracker.from_file``, and every field of ``best()`` is compared.
    """
    tracker = MetricTracker(patient=1)
    for record in METRICS:
        tracker.add_metric(**record)
        epoch = record["epoch"]
        if epoch <= 4:
            ASSERT.assertFalse(tracker.early_stopping(epoch))
        else:
            ASSERT.assertTrue(tracker.early_stopping(epoch))
        if tracker.early_stopping(epoch):
            break

    json_path = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/metric_tracker.json")
    tracker.save(json_path)
    reloaded = MetricTracker.from_file(json_path)

    original_best = tracker.best()
    reloaded_best = reloaded.best()

    ASSERT.assertEqual(original_best.epoch, reloaded_best.epoch)

    for field in ("train_metric", "validation_metric"):
        ASSERT.assertDictEqual(getattr(original_best, field), getattr(reloaded_best, field))

    for field in ("train_model_target_metric", "validation_model_target_metric"):
        before = getattr(original_best, field)
        after = getattr(reloaded_best, field)
        ASSERT.assertEqual(before.name, after.name)
        ASSERT.assertEqual(before.value, after.value)
def test_glove_loader():
    """Load the 3-dimensional sample embedding file with ``GloveLoader``.

    The loaded dictionary must equal exactly the three token -> vector
    entries below, and the loader must report ``embedding_dim == 3``.
    """
    sample_file = os.path.join(
        ROOT_PATH, "data/easytext/tests/pretrained/word_embedding_sample.3d.txt"
    )
    loader = GloveLoader(embedding_dim=3, pretrained_file_path=sample_file)
    loaded = loader.load()

    expected = {
        "a": [1.0, 2.0, 3.0],
        "b": [4.0, 5.0, 6.0],
        "美丽": [7.0, 8.0, 9.0],
    }
    ASSERT.assertDictEqual(expected, loaded)
    ASSERT.assertEqual(loader.embedding_dim, 3)
def test_synchronized_data():
    """Test ``to_synchronized_data`` and ``from_synchronized_data``.

    ``to_synchronized_data`` must return, under the keys ``true_positives``,
    ``false_positives`` and ``false_negatives``, values whose ``tolist()``
    matches the corresponding counter dict's values in iteration order.
    ``from_synchronized_data`` applied with the returned reduce op must then
    leave all three counter dicts equal to their prior contents.
    """
    demo_metric = _DemoF1Metric()
    sync_data, op = demo_metric.to_synchronized_data()

    true_positives = sync_data["true_positives"]
    false_positives = sync_data["false_positives"]
    false_negatives = sync_data["false_negatives"]

    # Read all three counters through the same public properties that the
    # round-trip checks below use, rather than mixing in private attributes.
    ASSERT.assertListEqual(list(demo_metric.true_positives.values()), true_positives.tolist())
    ASSERT.assertListEqual(list(demo_metric.false_positives.values()), false_positives.tolist())
    ASSERT.assertListEqual(list(demo_metric.false_negatives.values()), false_negatives.tolist())

    # Shallow-copy the counters before restoring so the comparison does not
    # see any in-place updates made by from_synchronized_data.
    expect_true_positives = dict(demo_metric.true_positives)
    expect_false_positives = dict(demo_metric.false_positives)
    expect_false_negatives = dict(demo_metric.false_negatives)

    demo_metric.from_synchronized_data(sync_data=sync_data, reduce_op=op)

    ASSERT.assertDictEqual(expect_true_positives, demo_metric.true_positives)
    ASSERT.assertDictEqual(expect_false_positives, demo_metric.false_positives)
    ASSERT.assertDictEqual(expect_false_negatives, demo_metric.false_negatives)
def _run_train(cuda_devices: List[str] = None):
    """Train the demo model, reload its checkpoint, and check nothing changed.

    The trainer state (model and optimizer state dicts, epoch counters, the
    first metric, and the metric tracker) is captured after training. It is
    captured again after ``load_checkpoint`` on the same serialize directory,
    and the two snapshots must be equal.

    NOTE(review): another ``_run_train`` with a ``(device, is_distributed)``
    signature appears later in this listing. If both are in the same module,
    this definition is shadowed and unused — confirm.

    :param cuda_devices: passed straight through to ``Trainer``.
    """
    checkpoint_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # Always start from an empty checkpoint directory.
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    os.makedirs(checkpoint_dir)

    demo_model = ModelDemo()
    demo_optimizer_factory = _DemoOptimizerFactory()
    demo_loss = _DemoLoss()
    demo_metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=demo_model,
                      loss=demo_loss,
                      metrics=demo_metric,
                      optimizer_factory=demo_optimizer_factory,
                      serialize_dir=checkpoint_dir,
                      patient=20,
                      num_check_point_keep=25,
                      cuda_devices=cuda_devices)

    dataset = _DemoDataset()
    loader_options = {"dataset": dataset, "batch_size": 200, "shuffle": False, "num_workers": 0}
    # Training and validation read the same dataset; each loader gets its own collate object.
    train_loader = DataLoader(collate_fn=_DemoCollate(), **loader_options)
    validation_loader = DataLoader(collate_fn=_DemoCollate(), **loader_options)

    trainer.train(train_data_loader=train_loader, validation_data_loader=validation_loader)

    def _snapshot():
        # JSON round-trip turns the state into plain dicts/lists comparable by value.
        return (json.loads(json2str(trainer.model.state_dict())),
                json.loads(json2str(trainer.optimizer.state_dict())),
                trainer.current_epoch,
                trainer.num_epoch,
                trainer.metrics.metric[0],
                json.loads(json2str(trainer.metric_tracker)))

    (model_before, optimizer_before, epoch_before,
     num_epoch_before, metric_before, tracker_before) = _snapshot()

    trainer.load_checkpoint(serialize_dir=checkpoint_dir)

    (model_after, optimizer_after, epoch_after,
     num_epoch_after, metric_after, tracker_after) = _snapshot()

    ASSERT.assertDictEqual(model_before, model_after)
    ASSERT.assertDictEqual(optimizer_before, optimizer_after)
    ASSERT.assertEqual(epoch_before, epoch_after)
    ASSERT.assertEqual(num_epoch_before, num_epoch_after)
    ASSERT.assertDictEqual(metric_before, metric_after)
    ASSERT.assertDictEqual(tracker_before, tracker_after)
def _run_train(device: torch.device, is_distributed: bool):
    """Train the demo model, reload its checkpoint, and check nothing changed.

    The trainer state (model and optimizer state dicts, epoch counters, the
    first metric, and the metric tracker) is captured after training. It is
    captured again after ``load_checkpoint`` on the same serialize directory,
    and the two snapshots must be equal.

    :param device: device passed through to ``Trainer``.
    :param is_distributed: when True, only rank 0 recreates the serialize
        directory (other ranks wait on ``TorchDist.barrier()``), and both data
        loaders use a ``DistributedSampler``.
    """
    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # Recreate the checkpoint directory from scratch. In distributed mode only
    # rank 0 touches the filesystem; the barrier keeps other ranks from moving
    # on before the directory exists.
    if not is_distributed or TorchDist.get_rank() == 0:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)
    if is_distributed:
        TorchDist.barrier()

    model = ModelDemo()
    optimizer_factory = _DemoOptimizerFactory()
    loss = _DemoLoss()
    metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      device=device,
                      trainer_callback=None,
                      is_distributed=is_distributed)

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logging.info("test is_distributed: %s", is_distributed)

    train_dataset = _DemoDataset()

    def _make_loader():
        # Build a loader over train_dataset with its own sampler and collate object.
        # A DistributedSampler shards the dataset across ranks; None leaves
        # DataLoader to its default sampling.
        sampler = DistributedSampler(dataset=train_dataset) if is_distributed else None
        return DataLoader(dataset=train_dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          num_workers=0,
                          sampler=sampler)

    # Validation deliberately reuses the training data.
    train_data_loader = _make_loader()
    validation_data_loader = _make_loader()

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)

    def _snapshot():
        # Capture comparable trainer state; the JSON round-trip turns it into plain dicts/lists.
        return (json.loads(json2str(trainer.model.state_dict())),
                json.loads(json2str(trainer.optimizer.state_dict())),
                trainer.current_epoch,
                trainer.num_epoch,
                trainer.metrics.metric[0],
                json.loads(json2str(trainer.metric_tracker)))

    (expect_model_state_dict, expect_optimizer_state_dict, expect_current_epoch,
     expect_num_epoch, expect_metric, expect_metric_tracker) = _snapshot()

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    (loaded_model_state_dict, loaded_optimizer_state_dict, current_epoch,
     num_epoch, loaded_metric, metric_tracker) = _snapshot()

    ASSERT.assertDictEqual(expect_model_state_dict, loaded_model_state_dict)
    ASSERT.assertDictEqual(expect_optimizer_state_dict, loaded_optimizer_state_dict)
    ASSERT.assertEqual(expect_current_epoch, current_epoch)
    ASSERT.assertEqual(expect_num_epoch, num_epoch)
    ASSERT.assertDictEqual(expect_metric, loaded_metric)
    ASSERT.assertDictEqual(expect_metric_tracker, metric_tracker)