예제 #1
0
def test_resume_checkpoint():
    """Check that training can be resumed from a previously saved checkpoint.

    An estimator is first fit for 2 epochs with a checkpoint handler that
    keeps a single checkpoint. A second handler, created with
    ``resume_from_checkpoint=True``, is then used to fit up to 5 epochs; the
    test expects only the remaining 3 epochs to be trained and the final
    trainer-states checkpoint to be the one for epoch 4.
    """
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_net'
        checkpoint_stem = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)

        # Arguments shared by the initial handler and the resuming handler.
        shared_kwargs = dict(model_dir=tmpdir,
                             model_prefix=model_prefix,
                             monitor=acc,
                             max_checkpoints=1)

        # First run: train 2 epochs from scratch.
        initial_handler = event_handler.CheckpointHandler(**shared_kwargs)
        est.fit(test_data, event_handlers=[initial_handler], epochs=2)
        for checkpoint_name in ('-epoch1batch8.params', '-epoch1batch8.states'):
            assert os.path.isfile(checkpoint_stem + checkpoint_name)

        # Second run: pick up from the checkpoint written above.
        resuming_handler = event_handler.CheckpointHandler(
            resume_from_checkpoint=True, **shared_kwargs)
        est.fit(test_data, event_handlers=[resuming_handler], epochs=5)

        # Only the 3 remaining epochs should be trained, and the last
        # checkpoint file should be the one for epoch 4.
        assert est.max_epoch == 3
        assert os.path.isfile(checkpoint_stem + '-epoch4batch20.states')
예제 #2
0
def test_checkpoint_handler():
    """Exercise CheckpointHandler in epoch-period and batch-period modes.

    Phase 1 fits for one epoch with ``save_best=True`` and ``epoch_period=1``
    and expects both the ``-best`` files and the per-epoch checkpoint files
    to exist under the ``test_epoch`` prefix.

    Phase 2 fits a hybridized network for 10 batches with ``batch_period=2``
    and ``max_checkpoints=2`` (``save_best`` is not requested) under the
    ``test_batch`` prefix. It expects a ``-symbol.json`` file and the two
    checkpoints named in the final asserts, and expects that no ``-best``
    files and no ``-epoch0batch0`` files are present.
    """
    with TemporaryDirectory() as tmpdir:
        # ---- Phase 1: checkpoint every epoch and track the best model ----
        model_prefix = 'test_epoch'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            monitor=acc,
            save_best=True,
            epoch_period=1)
        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
        assert checkpoint_handler.current_epoch == 1
        assert checkpoint_handler.current_batch == 4
        assert os.path.isfile(file_path + '-best.params')
        assert os.path.isfile(file_path + '-best.states')
        assert os.path.isfile(file_path + '-epoch0batch4.params')
        assert os.path.isfile(file_path + '-epoch0batch4.states')

        # ---- Phase 2: checkpoint every 2 batches, keep at most 2 ----
        model_prefix = 'test_batch'
        file_path = os.path.join(tmpdir, model_prefix)
        net = _get_test_network(nn.HybridSequential())
        net.hybridize()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            epoch_period=None,
            batch_period=2,
            max_checkpoints=2)
        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
        assert checkpoint_handler.current_batch == 10
        assert checkpoint_handler.current_epoch == 3
        # save_best was not requested, so no "-best" files may appear. The
        # leading hyphen matters: the phase-1 asserts show the files are named
        # "<prefix>-best.*", so checking "<prefix>best.*" would test a path
        # that is never written and pass vacuously.
        assert not os.path.isfile(file_path + '-best.params')
        assert not os.path.isfile(file_path + '-best.states')
        assert not os.path.isfile(file_path + '-epoch0batch0.params')
        assert not os.path.isfile(file_path + '-epoch0batch0.states')
        # The hybridized network is expected to produce a symbol file too.
        assert os.path.isfile(file_path + '-symbol.json')
        # Only these two batch checkpoints are expected to be present at the end.
        assert os.path.isfile(file_path + '-epoch1batch7.params')
        assert os.path.isfile(file_path + '-epoch1batch7.states')
        assert os.path.isfile(file_path + '-epoch2batch9.params')
        assert os.path.isfile(file_path + '-epoch2batch9.states')