def test_progress():
    """A ProgressBar advanced by 1 then 2 renders 3/100 with a bar and a numeric it/sec rate."""
    sink = BufferOutput()
    bar = ProgressBar(range(100), file=sink)
    bar.start()
    bar.update(1)
    # Short pause between updates — presumably so a non-zero elapsed time
    # feeds the it/sec figure matched below.
    time.sleep(0.001)
    bar.update(2)
    # Only the start of the rendered output is anchored (match, not fullmatch).
    rendered = re.compile(r' 3%\|█▏ \| 3/100 \[[0-9]+\.[0-9][0-9] it/sec\]')
    assert rendered.match(sink.getvalue())
def test_buffero_stats():
    """ESC[4A moves the cursor up four lines, so later prints overwrite the first two."""
    sink = BufferOutput()
    for text in ("Accuracy: -", "Validation Accuracy: -", "", "Training ..."):
        print(text, file=sink)
    # Four lines were written; cursor-up by 4 returns to the first of them.
    sink.write("\033[4A")
    for text in ("Accuracy: 0.73", "Validation Accuracy: 0.70"):
        print(text, file=sink)
    # The blank line and "Training ..." are left untouched.
    expected = "Accuracy: 0.73\nValidation Accuracy: 0.70\n\nTraining ...\n"
    assert sink.getvalue() == expected
def test_default_progress():
    """Display.progressbar with epochs=10 labels the bar 'Epoch 1/10' and shows '-' for the rate."""
    sink = BufferOutput()
    disp = Display(stdout=sink, stderr=sink)
    bar = disp.progressbar(range(100), epochs=10, file=sink)
    bar.start()
    # start() alone already reports one item done.
    assert sink.getvalue() == 'Epoch 1/10|▎ | 1/100 [ - it/sec]'
    bar.update(24)
    # 1 + 24 = 25 items done.
    assert sink.getvalue() == 'Epoch 1/10|████████▎ | 25/100 [ - it/sec]'
def test_training(tmpdir): d1 = tmpdir.mkdir("p1") buffer = BufferOutput() display = Display(stdout=buffer, stderr=buffer) env = Environment(project_dir=str(d1), display=display) _name = env.start_training() callback = env.progress_callback(epochs=10, steps=100) callback(0, 1, acc=0.56, loss=1.234) callback(0, 2, acc=0.56, loss=1.234, val_acc=0.77) callback(1, 1, acc=0.56, loss=1.234) callback(9, 99, acc=0.78, loss=0.837, val_acc=0.67) env.end_training(final_results=dict(val_acc=0.77)) assert env.get('results.val_acc') == 0.77 with open(os.path.join(env.AI_dir(), "data.yaml")) as f: data_yaml = yaml.load(f) with open(os.path.join(env.stats_dir(), "stats.csv")) as f: stats_csv = f.read() assert data_yaml['results']['status'] == 'FINISHED' assert data_yaml['results']['val_acc'] == 0.77 assert data_yaml['results']['acc'] == 0.78 assert stats_csv == """\
def test_buffero_too_far_down():
    """Moving the cursor down past the last line does not add blank lines."""
    sink = BufferOutput()
    # ESC[12B requests 12 lines down while only one line exists; the trailing
    # "!!" is expected to land right after "Hello World!" on the same line.
    sink.write("Hello World!\033[12B!!\n")
    expected = "Hello World!!!\n"
    assert sink.getvalue() == expected
def test_buffero_cr():
    """A carriage return rewinds to column 0, so 'Ha' overwrites 'He'."""
    sink = BufferOutput()
    sink.write("Hello World!\rHa\n")
    # Only the first two characters change; the rest of the line survives.
    expected = "Hallo World!\n"
    assert sink.getvalue() == expected
def test_buffero_multiline():
    """Consecutive prints accumulate as separate newline-terminated lines."""
    sink = BufferOutput()
    lines = ("Hello World 1!", "Hello World 2!", "Hello World 3!")
    for line in lines:
        print(line, file=sink)
    assert sink.getvalue() == "Hello World 1!\nHello World 2!\nHello World 3!\n"
def test_buffero_long():
    """A printed line that is too long is cut off; only the leading X run is kept."""
    sink = BufferOutput()
    too_long = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-YYYYYYYYYYYYYYYYYYYYY"
    print(too_long, file=sink)
    # Everything from "-" onward is dropped. Presumably the X run equals the
    # buffer's line width — TODO confirm against BufferOutput.
    assert sink.getvalue() == "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
def test_buffero():
    """A single print is captured verbatim, including its trailing newline."""
    sink = BufferOutput()
    print("Hello World!", file=sink)
    expected = "Hello World!\n"
    assert sink.getvalue() == expected
def test_default_table(): buffer = BufferOutput() display = Display(stdout=buffer, stderr=buffer) table = display.table([["Accuracy", "Val Accuracy", "Loss", "Val Loss"], [0.89, 0.88, 0.213, 0.334], [0.23, 0.89, 0.001, 0.003]], separate='none') assert str(table) == """\