def test_logging():
    """Run a 3-epoch fit with a file-backed LoggingHandler and check its counters.

    This variant targets the estimator API where metrics are passed via
    ``metrics=`` and the train/validation metrics are obtained from
    ``prepare_loss_and_metrics()``; the handler is told to write a log file
    named ``test_log`` into a temporary directory.
    """
    # NOTE(review): another ``def test_logging`` appears later in this file and
    # rebinds the name, so this definition is never collected as written.
    with TemporaryDirectory() as tmpdir:
        data = _get_test_data()
        log_name = 'test_log'
        log_path = os.path.join(tmpdir, log_name)
        model = _get_test_network()
        est = estimator.Estimator(
            model,
            loss=loss.SoftmaxCrossEntropyLoss(),
            metrics=mx.metric.Accuracy(),
        )
        train_metrics, val_metrics = est.prepare_loss_and_metrics()
        handler = event_handler.LoggingHandler(
            file_name=log_name,
            file_location=tmpdir,
            train_metrics=train_metrics,
            val_metrics=val_metrics,
        )
        num_epochs = 3
        est.fit(data, event_handlers=[handler], epochs=num_epochs)

        assert handler.batch_index == 0
        assert handler.current_epoch == num_epochs
        assert os.path.isfile(log_path)
def test_logging():
    """Check LoggingHandler's batch/epoch counters after a 3-epoch fit.

    A ``logging.FileHandler`` pointing at ``<tmpdir>/test_log`` is attached to
    the estimator's logger for the duration of the run. It is detached and
    closed explicitly before the temporary directory is removed.
    """
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        output_path = os.path.join(tmpdir, file_name)
        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        file_handler = logging.FileHandler(output_path)
        est.logger.addHandler(file_handler)
        try:
            logging_handler = event_handler.LoggingHandler(metrics=est.train_metrics)
            est.fit(test_data, event_handlers=[logging_handler], epochs=3)
            assert logging_handler.batch_index == 0
            assert logging_handler.current_epoch == 3
            assert os.path.isfile(output_path)
        finally:
            # Release the log file deterministically instead of relying on
            # garbage collection: an open file inside tmpdir can make
            # TemporaryDirectory cleanup fail on platforms that lock open files.
            est.logger.removeHandler(file_handler)
            file_handler.close()
            # Drop the estimator before tmpdir is deleted, as the original
            # test did; presumably it may still reference tmpdir resources.
            del est