# Code example #1
# 0
def test_metric_tracker_patient():
    """
    Early stopping with ``patient=1``.

    ``early_stopping`` must be False for every epoch up to 4 and True for any
    later epoch; the loop stops at the first True. The best record kept by the
    tracker must be epoch 3 (train acc 0.85, validation acc 0.60).
    """
    tracker = MetricTracker(patient=1)

    for record in METRICS:
        tracker.add_metric(**record)
        epoch = record["epoch"]

        # Pick the matching assertion for this epoch, then apply it.
        check = ASSERT.assertTrue if epoch > 4 else ASSERT.assertFalse
        check(tracker.early_stopping(epoch))

        if tracker.early_stopping(epoch):
            break

    expected_epoch = 3
    expected_train_metric = {"acc": 0.85}
    expected_train_target = ModelTargetMetric(metric_name="acc", metric_value=0.85)
    expected_validation_metric = {"acc": 0.60}
    expected_validation_target = ModelTargetMetric(metric_name="acc", metric_value=0.60)

    best = tracker.best()
    ASSERT.assertEqual(expected_epoch, best.epoch)

    ASSERT.assertDictEqual(expected_train_metric, best.train_metric)
    ASSERT.assertDictEqual(expected_validation_metric, best.validation_metric)

    # Compare name and value of the train target first, then of the validation target.
    target_pairs = (
        (expected_train_target, best.train_model_target_metric),
        (expected_validation_target, best.validation_model_target_metric),
    )
    for wanted, got in target_pairs:
        ASSERT.assertEqual(wanted.name, got.name)
        ASSERT.assertEqual(wanted.value, got.value)
# Code example #2
# 0
def test_metric_tracker_best():
    """
    Test the metric tracker's ``best()`` selection.

    With ``patient=None`` every entry of ``METRICS`` is added; ``best()`` must
    then report epoch 3 with train acc 0.85 and validation acc 0.60.
    :return:
    """
    tracker = MetricTracker(patient=None)

    for record in METRICS:
        tracker.add_metric(**record)

    expected_epoch = 3
    expected_train_metric = {"acc": 0.85}
    expected_train_target = ModelTargetMetric(metric_name="acc", metric_value=0.85)
    expected_validation_metric = {"acc": 0.60}
    expected_validation_target = ModelTargetMetric(metric_name="acc", metric_value=0.60)

    best = tracker.best()
    ASSERT.assertEqual(expected_epoch, best.epoch)

    ASSERT.assertDictEqual(expected_train_metric, best.train_metric)
    ASSERT.assertDictEqual(expected_validation_metric, best.validation_metric)

    # Train target is checked before validation target, name before value.
    target_pairs = (
        (expected_train_target, best.train_model_target_metric),
        (expected_validation_target, best.validation_model_target_metric),
    )
    for wanted, got in target_pairs:
        ASSERT.assertEqual(wanted.name, got.name)
        ASSERT.assertEqual(wanted.value, got.value)
# Code example #3
# 0
def test_metric_tracker_save_and_load():
    """
    Round-trip a tracker through ``save`` / ``MetricTracker.from_file``.

    The tracker is filled with the same ``patient=1`` early-stopping loop as the
    patience test, saved to ``metric_tracker.json`` under ``ROOT_PATH``, reloaded,
    and the reloaded tracker's ``best()`` must match the original's field by field.
    """
    tracker = MetricTracker(patient=1)

    for record in METRICS:
        tracker.add_metric(**record)
        epoch = record["epoch"]

        check = ASSERT.assertTrue if epoch > 4 else ASSERT.assertFalse
        check(tracker.early_stopping(epoch))

        if tracker.early_stopping(epoch):
            break

    json_path = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/metric_tracker.json")
    tracker.save(json_path)
    restored = MetricTracker.from_file(json_path)

    best = tracker.best()
    restored_best = restored.best()

    ASSERT.assertEqual(best.epoch, restored_best.epoch)

    for attr in ("train_metric", "validation_metric"):
        ASSERT.assertDictEqual(getattr(best, attr), getattr(restored_best, attr))

    for attr in ("train_model_target_metric", "validation_model_target_metric"):
        saved_target = getattr(best, attr)
        loaded_target = getattr(restored_best, attr)
        ASSERT.assertEqual(saved_target.name, loaded_target.name)
        ASSERT.assertEqual(saved_target.value, loaded_target.value)
# Code example #4
# 0
def test_glove_loader():
    """
    ``GloveLoader`` reads the 3-dimensional sample file into a word -> vector dict
    and exposes the configured ``embedding_dim``.
    """
    sample_path = os.path.join(
        ROOT_PATH,
        "data/easytext/tests/pretrained/word_embedding_sample.3d.txt",
    )

    loader = GloveLoader(embedding_dim=3, pretrained_file_path=sample_path)
    loaded_vectors = loader.load()

    ASSERT.assertDictEqual(
        {
            "a": [1.0, 2.0, 3.0],
            "b": [4.0, 5.0, 6.0],
            "美丽": [7.0, 8.0, 9.0],
        },
        loaded_vectors,
    )
    ASSERT.assertEqual(loader.embedding_dim, 3)
# Code example #5
# 0
# File: test_f1_metric.py  Project: piaoxue88/easytext
def test_synchronized_data():
    """
    Test to_synchronized_data and from_synchronized_data.

    ``to_synchronized_data`` must expose each counter dict (true positives,
    false positives, false negatives) as an array-like whose ``tolist()`` equals
    the dict's values in iteration order. Feeding that data back through
    ``from_synchronized_data`` with the returned reduce op must leave all three
    counters unchanged.
    :return:
    """

    demo_metric = _DemoF1Metric()

    sync_data, op = demo_metric.to_synchronized_data()

    # All three counters are read through their public accessors, consistent
    # with the round-trip checks below.
    counters = (
        ("true_positives", demo_metric.true_positives),
        ("false_positives", demo_metric.false_positives),
        ("false_negatives", demo_metric.false_negatives),
    )
    for key, counter in counters:
        ASSERT.assertListEqual(list(counter.values()), sync_data[key].tolist())

    # Copy the counters before the round trip, in case from_synchronized_data
    # updates the dicts in place.
    expect_true_positives = dict(demo_metric.true_positives)
    expect_false_positives = dict(demo_metric.false_positives)
    expect_false_negatives = dict(demo_metric.false_negatives)

    demo_metric.from_synchronized_data(sync_data=sync_data, reduce_op=op)

    ASSERT.assertDictEqual(expect_true_positives, demo_metric.true_positives)
    ASSERT.assertDictEqual(expect_false_positives, demo_metric.false_positives)
    ASSERT.assertDictEqual(expect_false_negatives, demo_metric.false_negatives)
# Code example #6
# 0
def _run_train(cuda_devices: List[str] = None):
    """
    Train the demo model, reload the checkpoint from ``serialize_dir`` and assert
    that the reloaded trainer state equals the state captured right after training.

    :param cuda_devices: passed unchanged to ``Trainer(cuda_devices=...)``.
    """
    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # Start from an empty directory so only checkpoints written by this run exist.
    if os.path.isdir(serialize_dir):
        shutil.rmtree(serialize_dir)
    os.makedirs(serialize_dir)

    model = ModelDemo()
    optimizer_factory = _DemoOptimizerFactory()
    loss = _DemoLoss()
    metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      cuda_devices=cuda_devices)

    dataset = _DemoDataset()

    def make_loader():
        # Training and validation both iterate the same dataset in a fixed order.
        return DataLoader(dataset=dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          shuffle=False,
                          num_workers=0)

    trainer.train(train_data_loader=make_loader(),
                  validation_data_loader=make_loader())

    def snapshot():
        # json2str -> json.loads turns the states into plain dicts that can be compared.
        return (json.loads(json2str(trainer.model.state_dict())),
                json.loads(json2str(trainer.optimizer.state_dict())),
                trainer.current_epoch,
                trainer.num_epoch,
                trainer.metrics.metric[0],
                json.loads(json2str(trainer.metric_tracker)))

    (want_model, want_optimizer, want_epoch,
     want_num_epoch, want_metric, want_tracker) = snapshot()

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    (got_model, got_optimizer, got_epoch,
     got_num_epoch, got_metric, got_tracker) = snapshot()

    ASSERT.assertDictEqual(want_model, got_model)
    ASSERT.assertDictEqual(want_optimizer, got_optimizer)
    ASSERT.assertEqual(want_epoch, got_epoch)
    ASSERT.assertEqual(want_num_epoch, got_num_epoch)
    ASSERT.assertDictEqual(want_metric, got_metric)
    ASSERT.assertDictEqual(want_tracker, got_tracker)
# Code example #7
# 0
def _run_train(device: torch.device, is_distributed: bool):
    """
    Train the demo model (optionally in distributed mode), reload the checkpoint
    from ``serialize_dir`` and assert that the reloaded trainer state equals the
    state captured right after training.

    :param device: passed unchanged to ``Trainer(device=...)``.
    :param is_distributed: when True, only rank 0 resets the checkpoint directory
        (all ranks then wait at a barrier), and each DataLoader gets its own
        ``DistributedSampler``.
    """
    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # Reset the checkpoint directory. In distributed mode only rank 0 touches the
    # file system; every rank waits at the barrier until it is done.
    if not is_distributed or TorchDist.get_rank() == 0:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

    if is_distributed:
        TorchDist.barrier()

    model = ModelDemo()
    optimizer_factory = _DemoOptimizerFactory()
    loss = _DemoLoss()
    metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      device=device,
                      trainer_callback=None,
                      is_distributed=is_distributed)
    logging.info(f"test is_distributed: {is_distributed}")

    train_dataset = _DemoDataset()

    def make_loader():
        # Each loader gets its own DistributedSampler in distributed mode;
        # otherwise the DataLoader's default (sequential) sampling is used.
        sampler = DistributedSampler(dataset=train_dataset) if is_distributed else None
        return DataLoader(dataset=train_dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          num_workers=0,
                          sampler=sampler)

    train_data_loader = make_loader()
    validation_data_loader = make_loader()

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)

    def snapshot():
        # json2str -> json.loads turns the states into plain dicts that can be compared.
        return (json.loads(json2str(trainer.model.state_dict())),
                json.loads(json2str(trainer.optimizer.state_dict())),
                trainer.current_epoch,
                trainer.num_epoch,
                trainer.metrics.metric[0],
                json.loads(json2str(trainer.metric_tracker)))

    (expect_model_state_dict, expect_optimizer_state_dict, expect_current_epoch,
     expect_num_epoch, expect_metric, expect_metric_tracker) = snapshot()

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    (loaded_model_state_dict, loaded_optimizer_state_dict, current_epoch,
     num_epoch, loaded_metric, metric_tracker) = snapshot()

    ASSERT.assertDictEqual(expect_model_state_dict, loaded_model_state_dict)
    ASSERT.assertDictEqual(expect_optimizer_state_dict, loaded_optimizer_state_dict)
    ASSERT.assertEqual(expect_current_epoch, current_epoch)
    ASSERT.assertEqual(expect_num_epoch, num_epoch)
    ASSERT.assertDictEqual(expect_metric, loaded_metric)
    ASSERT.assertDictEqual(expect_metric_tracker, metric_tracker)