Пример #1
0
def _run_train(cuda_devices: List[str] = None):
    """
    Train the demo model, reload the checkpoint written to the serialize
    directory, and assert that the reloaded trainer state equals the state
    observed right after training.

    :param cuda_devices: passed through unchanged to ``Trainer(cuda_devices=...)``.
    """

    def _capture(owner):
        """Collect the trainer fields compared across the save/load round trip."""
        # json2str followed by json.loads turns the state into plain Python
        # data that assertDictEqual can compare.
        model_state = json.loads(json2str(owner.model.state_dict()))
        optimizer_state = json.loads(json2str(owner.optimizer.state_dict()))
        epoch_now = owner.current_epoch
        epoch_total = owner.num_epoch
        first_metric = owner.metrics.metric[0]
        tracker_state = json.loads(json2str(owner.metric_tracker))
        return (model_state, optimizer_state, epoch_now, epoch_total,
                first_metric, tracker_state)

    def _make_loader(dataset):
        """Build a non-shuffling, single-process loader over ``dataset``."""
        return DataLoader(dataset=dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          shuffle=False,
                          num_workers=0)

    checkpoint_dir = os.path.join(
        ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # Always start from an empty checkpoint directory.
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    os.makedirs(checkpoint_dir)

    demo_model = ModelDemo()
    demo_optimizer_factory = _DemoOptimizerFactory()
    demo_loss = _DemoLoss()
    demo_metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=demo_model,
                      loss=demo_loss,
                      metrics=demo_metric,
                      optimizer_factory=demo_optimizer_factory,
                      serialize_dir=checkpoint_dir,
                      patient=20,
                      num_check_point_keep=25,
                      cuda_devices=cuda_devices)

    # The same dataset backs both the training and the validation loader.
    demo_dataset = _DemoDataset()
    train_loader = _make_loader(demo_dataset)
    validation_loader = _make_loader(demo_dataset)

    trainer.train(train_data_loader=train_loader,
                  validation_data_loader=validation_loader)

    # Snapshot after training, reload from disk, snapshot again.
    expected = _capture(trainer)
    trainer.load_checkpoint(serialize_dir=checkpoint_dir)
    restored = _capture(trainer)

    (expect_model_state, expect_optimizer_state, expect_epoch,
     expect_total_epoch, expect_first_metric, expect_tracker) = expected
    (loaded_model_state, loaded_optimizer_state, loaded_epoch,
     loaded_total_epoch, loaded_first_metric, loaded_tracker) = restored

    ASSERT.assertDictEqual(expect_model_state, loaded_model_state)
    ASSERT.assertDictEqual(expect_optimizer_state, loaded_optimizer_state)
    ASSERT.assertEqual(expect_epoch, loaded_epoch)
    ASSERT.assertEqual(expect_total_epoch, loaded_total_epoch)
    ASSERT.assertDictEqual(expect_first_metric, loaded_first_metric)
    ASSERT.assertDictEqual(expect_tracker, loaded_tracker)
Пример #2
0
def _run_train(device: torch.device, is_distributed: bool):
    """
    Train the demo model, reload the checkpoint written to the serialize
    directory, and assert that the reloaded trainer state equals the state
    observed right after training.

    :param device: passed through unchanged to ``Trainer(device=...)``.
    :param is_distributed: when True, only rank 0 resets the serialize
        directory, all processes then wait on a barrier, and both data
        loaders draw their samples through a ``DistributedSampler``.
        When False the directory is reset directly and no sampler is used.
    """

    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    # A single process resets the directory: the only process when not
    # distributed, otherwise rank 0. The rank is queried only in
    # distributed mode thanks to short-circuit evaluation.
    if not is_distributed or TorchDist.get_rank() == 0:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)

        os.makedirs(serialize_dir)

    if is_distributed:
        # Every rank waits here until rank 0 has recreated the directory.
        TorchDist.barrier()

    model = ModelDemo()

    optimizer_factory = _DemoOptimizerFactory()

    loss = _DemoLoss()
    metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      device=device,
                      trainer_callback=None,
                      is_distributed=is_distributed
                      )
    logging.info("test is_distributed: %s", is_distributed)

    train_dataset = _DemoDataset()

    def _build_loader():
        """Build one loader over train_dataset; each call gets its own sampler."""
        sampler = DistributedSampler(dataset=train_dataset) if is_distributed else None

        return DataLoader(dataset=train_dataset,
                          collate_fn=_DemoCollate(),
                          batch_size=200,
                          num_workers=0,
                          sampler=sampler)

    # The same dataset backs both the training and the validation loader.
    train_data_loader = _build_loader()
    validation_data_loader = _build_loader()

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)

    def _snapshot():
        """Collect the trainer fields compared across the save/load round trip."""
        # json2str followed by json.loads turns the state into plain Python
        # data that assertDictEqual can compare.
        return (
            json.loads(json2str(trainer.model.state_dict())),
            json.loads(json2str(trainer.optimizer.state_dict())),
            trainer.current_epoch,
            trainer.num_epoch,
            trainer.metrics.metric[0],
            json.loads(json2str(trainer.metric_tracker)),
        )

    (expect_model_state_dict,
     expect_optimizer_state_dict,
     expect_current_epoch,
     expect_num_epoch,
     expect_metric,
     expect_metric_tracker) = _snapshot()

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    (loaded_model_state_dict,
     loaded_optimizer_state_dict,
     loaded_current_epoch,
     loaded_num_epoch,
     loaded_metric,
     loaded_metric_tracker) = _snapshot()

    ASSERT.assertDictEqual(expect_model_state_dict, loaded_model_state_dict)
    ASSERT.assertDictEqual(expect_optimizer_state_dict, loaded_optimizer_state_dict)
    ASSERT.assertEqual(expect_current_epoch, loaded_current_epoch)
    ASSERT.assertEqual(expect_num_epoch, loaded_num_epoch)
    ASSERT.assertDictEqual(expect_metric, loaded_metric)
    ASSERT.assertDictEqual(expect_metric_tracker, loaded_metric_tracker)