def test_early_stop(self):
    """Verify that a plugin can end a training epoch early.

    A plugin calls ``strategy.stop_training()`` as soon as
    ``strategy.clock.train_epoch_iterations`` equals ``stop_after``. After
    every experience of the fast benchmark has been trained, the test
    expects that counter to be exactly one past the threshold.
    """
    # Iteration count at which the plugin requests a stop.
    stop_after = 10

    class EarlyStopP(SupervisedPlugin):
        # Requests a stop once the in-epoch iteration counter hits the
        # threshold captured from the enclosing test.
        def after_training_iteration(
            self, strategy: "SupervisedTemplate", **kwargs
        ):
            if strategy.clock.train_epoch_iterations == stop_after:
                strategy.stop_training()

    net = SimpleMLP(input_size=6, hidden_size=100)
    loss_fn = CrossEntropyLoss()
    opt = SGD(net.parameters(), lr=1)
    strategy = Cumulative(
        net,
        opt,
        loss_fn,
        train_mb_size=1,
        device=get_device(),
        eval_mb_size=512,
        train_epochs=1,
        evaluator=None,
        plugins=[EarlyStopP()],
    )

    benchmark = get_fast_benchmark()
    for experience in benchmark.train_stream:
        strategy.train(experience)

    # NOTE(review): one past the threshold, presumably because the clock's
    # counter is advanced after this plugin's callback has run — confirm.
    assert strategy.clock.train_epoch_iterations == stop_after + 1
def test_cumulative(self):
    """Run the Cumulative strategy on two benchmark variants.

    First on the SIT benchmark returned by ``self.init_sit()``, then on a
    benchmark loaded with ``use_task_labels=True`` (MT scenario). The same
    model, optimizer and criterion objects are passed to both strategy
    instances; they are not re-created between the two runs.
    """
    model, optimizer, criterion, my_nc_benchmark = self.init_sit()

    def make_strategy():
        # Identical hyper-parameters for both scenarios.
        return Cumulative(
            model,
            optimizer,
            criterion,
            train_mb_size=64,
            device=self.device,
            eval_mb_size=50,
            train_epochs=2,
        )

    # SIT scenario
    self.run_strategy(my_nc_benchmark, make_strategy())

    # MT scenario
    mt_strategy = make_strategy()
    mt_benchmark = self.load_benchmark(use_task_labels=True)
    self.run_strategy(mt_benchmark, mt_strategy)
def test_multihead_cumulative(self):
    """Check that Cumulative with a multi-head model learns the train stream.

    Ensures nothing weird is happening with the multiple heads: after
    training on every experience of a task-labelled fast benchmark and
    evaluating on the whole train stream, the mean of the values reported
    by ``StreamAccuracy`` must exceed 0.7.
    """
    set_deterministic_run(seed=0)
    model = MHTestMLP(input_size=6, hidden_size=100)
    criterion = CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=1)
    main_metric = StreamAccuracy()
    exp_acc = ExperienceAccuracy()
    evalp = EvaluationPlugin(main_metric, exp_acc, loggers=None)
    strategy = Cumulative(
        model,
        optimizer,
        criterion,
        train_mb_size=64,
        device=get_device(),
        eval_mb_size=512,
        train_epochs=6,
        evaluator=evalp,
    )
    benchmark = get_fast_benchmark(use_task_labels=True)

    for train_batch_info in benchmark.train_stream:
        strategy.train(train_batch_info)
    strategy.eval(benchmark.train_stream[:])

    # Read the metric once instead of re-evaluating it for the print and
    # for each part of the assertion. The code below treats it as a
    # mapping; presumably keyed by task label — confirm.
    accuracies = main_metric.result()
    print("TRAIN STREAM ACC: ", accuracies)

    # An empty result would otherwise surface as ZeroDivisionError,
    # hiding the real problem (no accuracy was recorded).
    assert accuracies, "StreamAccuracy produced no results"
    mean_acc = sum(accuracies.values()) / len(accuracies)
    assert mean_acc > 0.7, f"mean train-stream accuracy too low: {mean_acc}"