Ejemplo n.º 1
0
def test_control_C():
    """Manually exercise the ``ControlC`` callback over two training runs.

    This is an interactive test: press Control+C during each of the two
    training runs. The first trainer is given ``ControlC(False)`` and the
    second ``ControlC(True)``. The test passes if the final "Test failed!"
    banner is never shown. Presumably ``ControlC(True)`` terminates the
    program on interruption; confirm against fastNLP's ControlC docs.
    """
    from fastNLP import ControlC, Callback
    import time

    header = "\n\n\n\n\n*************************"
    footer = "*************************\n\n\n\n\n"

    def announce(message):
        # Wrap the message in blank lines and a star rule so it stands out
        # among the training logs.
        print(header + message + footer)

    class Wait(Callback):
        # Pause for 5 seconds after each epoch to leave time to press Control+C.
        def on_epoch_end(self):
            time.sleep(5)

    data_set, model = prepare_env()

    def build_trainer(quit_all):
        # Both runs share every setting except the flag passed to ControlC.
        return Trainer(data_set, model,
                       optimizer=SGD(lr=0.1),
                       loss=BCELoss(pred="predict", target="y"),
                       batch_size=32, n_epochs=20, dev_data=data_set,
                       metrics=AccuracyMetric(pred="predict", target="y"),
                       use_tqdm=True,
                       callbacks=[Wait(), ControlC(quit_all)],
                       check_code_level=2)

    announce("Test starts!")
    build_trainer(False).train()

    announce("Program goes on ...")

    build_trainer(True).train()

    announce("Test failed!")
Ejemplo n.º 2
0
    def test_control_C_callback(self):
        """Check how ``ControlC`` drives its flag callback for each ``quit_all`` value.

        A callback raises KeyboardInterrupt at the end of the first epoch.
        The test expects ``set_flag`` to leave the flag False after the
        ``quit_all=False`` run and True after the ``quit_all=True`` run.
        """
        class Raise(Callback):
            # Simulate the user pressing Control+C when an epoch ends.
            def on_epoch_end(self):
                raise KeyboardInterrupt

        flag = False

        def set_flag():
            # Toggle the enclosing test's flag each time ControlC calls us.
            nonlocal flag
            flag = not flag

        data_set, model = prepare_env()

        def run(quit_all):
            # Identical training setup for both runs; only quit_all differs.
            trainer = Trainer(data_set, model,
                              optimizer=SGD(lr=0.1),
                              loss=BCELoss(pred="predict", target="y"),
                              batch_size=32, n_epochs=20, dev_data=data_set,
                              metrics=AccuracyMetric(pred="predict", target="y"),
                              use_tqdm=True,
                              callbacks=[Raise(), ControlC(quit_all, set_flag)],
                              check_code_level=2)
            trainer.train()

        run(False)
        self.assertEqual(flag, False)

        run(True)
        self.assertEqual(flag, True)
Ejemplo n.º 3
0
 def test_KeyBoardInterrupt(self):
     """Run a short training loop with a ``ControlC(False)`` callback attached.

     No callback registered here raises KeyboardInterrupt. The test
     therefore only checks that training runs to completion with
     ControlC installed.
     """
     data_set, model = prepare_env()

     # Built in the same order as the original inline arguments:
     # loss, then optimizer, then callbacks.
     loss_fn = BCELoss(pred="predict", target="y")
     optimizer = SGD(lr=0.1)
     callbacks = [ControlC(False)]

     trainer = Trainer(data_set, model,
                       loss=loss_fn, optimizer=optimizer,
                       n_epochs=5, batch_size=32, print_every=50,
                       check_code_level=2, use_tqdm=False,
                       callbacks=callbacks)
     trainer.train()