def test_logging():
    """Fit an Estimator with a file-backed LoggingHandler and check the result.

    Trains for three epochs on the shared test data/network, then checks the
    handler's ``batch_index`` / ``current_epoch`` counters and that the log
    file ``test_log`` was created inside a temporary directory.

    NOTE(review): another ``test_logging`` is defined further down this file
    and rebinds the name at import time, so this version is shadowed and will
    not be collected by pytest as written.
    """
    with TemporaryDirectory() as scratch_dir:
        data = _get_test_data()
        log_name = 'test_log'
        log_path = os.path.join(scratch_dir, log_name)  # full path of the log *file*

        model = _get_test_network()
        # Keyword arguments are evaluated left to right, so the loss is built
        # before the metric, exactly as when they were separate statements.
        est = estimator.Estimator(
            model,
            loss=loss.SoftmaxCrossEntropyLoss(),
            metrics=mx.metric.Accuracy(),
        )
        train_metrics, val_metrics = est.prepare_loss_and_metrics()

        handler_options = dict(
            file_name=log_name,
            file_location=scratch_dir,
            train_metrics=train_metrics,
            val_metrics=val_metrics,
        )
        log_handler = event_handler.LoggingHandler(**handler_options)
        est.fit(data, event_handlers=[log_handler], epochs=3)

        # After three epochs the epoch counter reads 3 and the batch index is 0
        # (presumably reset at each epoch boundary -- confirm in LoggingHandler).
        assert log_handler.batch_index == 0
        assert log_handler.current_epoch == 3
        assert os.path.isfile(log_path)


# Example #2
def test_logging():
    """Fit an Estimator with a LoggingHandler and check its bookkeeping.

    Estimator log output is mirrored into ``test_log`` inside a temporary
    directory by a ``logging.FileHandler`` attached to ``est.logger``. After
    three epochs of training, the test checks the handler's ``batch_index`` /
    ``current_epoch`` counters and that the log file exists.

    The file handler is detached and closed explicitly before the temporary
    directory is removed. The test does not rely on garbage collection to
    release the open log file (an open file blocks directory removal on
    Windows, for example).
    """
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        output_dir = os.path.join(tmpdir, file_name)  # path of the log *file*

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)

        # Keep a reference so the handler can be removed and closed deterministically.
        file_handler = logging.FileHandler(output_dir)
        est.logger.addHandler(file_handler)
        try:
            logging_handler = event_handler.LoggingHandler(metrics=est.train_metrics)
            est.fit(test_data, event_handlers=[logging_handler], epochs=3)
            assert logging_handler.batch_index == 0
            assert logging_handler.current_epoch == 3
            assert os.path.isfile(output_dir)
        finally:
            # Release the log file before TemporaryDirectory deletes tmpdir,
            # even when training or an assertion above fails.
            est.logger.removeHandler(file_handler)
            file_handler.close()
        del est  # drop the estimator (and its logger) before tmpdir is removed