Esempio n. 1
0
def test_train_iter2():
    cfg = get_cfg()
    cfg.LOG_PERIOD = 1
    train_meter = ClevrerTrainMeter(5, cfg)
    train_meter.update_stats(top1_err=40.0,
                             top5_err=20.0,
                             mc_opt_err=55.0,
                             mc_q_err=78.0,
                             loss_des=6.3,
                             loss_mc=0.2,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    train_meter.update_stats(top1_err=41.0,
                             top5_err=21.0,
                             mc_opt_err=56.0,
                             mc_q_err=86.0,
                             loss_des=6.7,
                             loss_mc=1.2,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    stats = train_meter.log_iter_stats(cur_epoch=1, cur_iter=1)
    assert stats['top1_err'] == 41.0
    assert stats['top5_err'] == 21.0
    assert stats['mc_opt_err'] == 56.0
    assert stats['mc_q_err'] == 86.0
    assert stats['loss_des'] == 6.7
    assert stats['loss_mc'] == 1.2
Esempio n. 2
0
def test_train_iter4():
    cfg = get_cfg()
    cfg.LOG_PERIOD = 2
    train_meter = ClevrerTrainMeter(5, cfg)
    train_meter.update_stats(top1_err=40.0,
                             top5_err=20.0,
                             mc_opt_err=55.0,
                             mc_q_err=78.0,
                             loss_des=6.3,
                             loss_mc=0.2,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    train_meter.update_stats(top1_err=10.0,
                             top5_err=10.0,
                             mc_opt_err=10.0,
                             mc_q_err=10.0,
                             loss_des=10.0,
                             loss_mc=10.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    stats = train_meter.log_iter_stats(cur_epoch=1, cur_iter=2)
    assert stats is None
    train_meter.update_stats(top1_err=20.0,
                             top5_err=20.0,
                             mc_opt_err=20.0,
                             mc_q_err=20.0,
                             loss_des=20.0,
                             loss_mc=20.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    stats = train_meter.log_iter_stats(cur_epoch=0, cur_iter=3)
    assert stats['top1_err'] == 15.0
    assert stats['top5_err'] == 15.0
    assert stats['mc_opt_err'] == 15.0
    assert stats['mc_q_err'] == 15.0
    assert stats['loss_des'] == 15.0
    assert stats['loss_mc'] == 15.0
Esempio n. 3
0
def test_train_iter5():
    cfg = get_cfg()
    cfg.LOG_PERIOD = 3
    train_meter = ClevrerTrainMeter(5, cfg)
    train_meter.update_stats(top1_err=30.0,
                             top5_err=30.0,
                             mc_opt_err=30.0,
                             mc_q_err=30.0,
                             loss_des=30.0,
                             loss_mc=30.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    train_meter.update_stats(top1_err=10.0,
                             top5_err=10.0,
                             mc_opt_err=10.0,
                             mc_q_err=10.0,
                             loss_des=10.0,
                             loss_mc=10.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    stats = train_meter.log_iter_stats(cur_epoch=1, cur_iter=1)
    assert stats is None
    train_meter.update_stats(top1_err=20.0,
                             top5_err=20.0,
                             mc_opt_err=20.0,
                             mc_q_err=20.0,
                             loss_des=20.0,
                             loss_mc=20.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=10)
    stats = train_meter.log_iter_stats(cur_epoch=0, cur_iter=5)
    assert stats['top1_err'] == 20.0, print(stats['top1_err'])
    assert stats['top5_err'] == 20.0
    assert stats['mc_opt_err'] == 20.0
    assert stats['mc_q_err'] == 20.0
    assert stats['loss_des'] == 20.0
    assert stats['loss_mc'] == 20.0
Esempio n. 4
0
def test_train_epoch_only_des():
    cfg = get_cfg()
    cfg.LOG_PERIOD = 3
    train_meter = ClevrerTrainMeter(5, cfg)
    train_meter.update_stats(top1_err=30.0,
                             top5_err=30.0,
                             mc_opt_err=30.0,
                             mc_q_err=30.0,
                             loss_des=30.0,
                             loss_mc=30.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=0)
    train_meter.update_stats(top1_err=10.0,
                             top5_err=10.0,
                             mc_opt_err=10.0,
                             mc_q_err=10.0,
                             loss_des=10.0,
                             loss_mc=10.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=0)
    train_meter.update_stats(top1_err=20.0,
                             top5_err=20.0,
                             mc_opt_err=20.0,
                             mc_q_err=20.0,
                             loss_des=20.0,
                             loss_mc=20.0,
                             lr=0.1,
                             mb_size_des=10,
                             mb_size_mc=0)
    stats = train_meter.log_epoch_stats(cur_epoch=1)
    assert stats['top1_err'] == 20.0, print(stats['top1_err'])
    assert stats['top5_err'] == 20.0
    assert not 'mc_opt_err' in stats
    assert not 'mc_q_err' in stats
    assert stats['loss_des'] == 20.0
    assert not 'loss_mc' in stats