def test_new_events():
    """Attaching a CustomPeriodicEvent registers its STARTED/COMPLETED events.

    Both the iteration-period and the epoch-period variants are attached to the
    same engine. For each variant the event object must expose an ``Events``
    namespace containing ``<PREFIX>_STARTED`` and ``<PREFIX>_COMPLETED``, and
    those two events must be the last two entries of the engine's allowed
    events right after ``attach``.

    NOTE(review): another ``test_new_events`` is defined later in this source
    and rebinds the name, so pytest only collects that later definition.
    """

    def update(*args, **kwargs):
        pass

    engine = Engine(update)

    variants = (
        ({"n_iterations": 5}, "ITERATIONS_5"),
        ({"n_epochs": 5}, "EPOCHS_5"),
    )
    for ctor_kwargs, prefix in variants:
        cpe = CustomPeriodicEvent(**ctor_kwargs)
        cpe.attach(engine)

        started_name = prefix + "_STARTED"
        completed_name = prefix + "_COMPLETED"

        assert hasattr(cpe, "Events")
        assert hasattr(cpe.Events, started_name)
        assert hasattr(cpe.Events, completed_name)
        # The newly registered pair is appended at the end of the engine's list.
        assert engine._allowed_events[-2] == getattr(cpe.Events, started_name)
        assert engine._allowed_events[-1] == getattr(cpe.Events, completed_name)
def test_new_events():
    """Attaching a (deprecated) CustomPeriodicEvent registers its new events.

    Every construction must emit a ``DeprecationWarning`` mentioning
    "CustomPeriodicEvent is deprecated". After ``attach``, the event object must
    expose ``Events.<PREFIX>_STARTED`` / ``Events.<PREFIX>_COMPLETED``, and these
    must be the last two entries of the engine's allowed events. Checked for an
    iteration period of 5 and then an epoch period of 5 on the same engine.
    """

    def update(*args, **kwargs):
        pass

    engine = Engine(update)

    variants = (
        ({"n_iterations": 5}, "ITERATIONS_5"),
        ({"n_epochs": 5}, "EPOCHS_5"),
    )
    for ctor_kwargs, prefix in variants:
        with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
            cpe = CustomPeriodicEvent(**ctor_kwargs)
            cpe.attach(engine)

        started_name = prefix + "_STARTED"
        completed_name = prefix + "_COMPLETED"

        assert hasattr(cpe, "Events")
        assert hasattr(cpe.Events, started_name)
        assert hasattr(cpe.Events, completed_name)
        # The newly registered pair is appended at the end of the engine's list.
        assert engine._allowed_events[-2] == getattr(cpe.Events, started_name)
        assert engine._allowed_events[-1] == getattr(cpe.Events, completed_name)
def test_bad_input():
    """Invalid CustomPeriodicEvent constructor arguments raise ValueError.

    Covers non-integer ("a", 10.0) and non-positive (0) periods for both
    ``n_iterations`` and ``n_epochs``, passing neither argument, and passing both.

    NOTE(review): this version expects ValueError for non-integer periods, while
    the later ``test_bad_input`` in this source expects TypeError for them. That
    later definition rebinds the name, so this one is not collected by pytest.
    """
    rejected_kwargs = (
        {"n_iterations": "a"},
        {"n_iterations": 0},
        {"n_iterations": 10.0},
        {"n_epochs": "a"},
        {"n_epochs": 0},
        {"n_epochs": 10.0},
        {},
        {"n_iterations": 1, "n_epochs": 2},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            CustomPeriodicEvent(**kwargs)
def test_bad_input():
    """Invalid CustomPeriodicEvent arguments raise the documented errors.

    Non-integer periods raise TypeError, non-positive periods raise ValueError,
    and giving neither or both of ``n_iterations``/``n_epochs`` raises ValueError.
    The whole check runs under ``pytest.warns`` because constructing the class
    emits a deprecation warning.
    """
    # (expected exception, message regex, constructor kwargs)
    cases = (
        (TypeError, "Argument n_iterations should be an integer", {"n_iterations": "a"}),
        (ValueError, "Argument n_iterations should be positive", {"n_iterations": 0}),
        (TypeError, "Argument n_iterations should be an integer", {"n_iterations": 10.0}),
        (TypeError, "Argument n_epochs should be an integer", {"n_epochs": "a"}),
        (ValueError, "Argument n_epochs should be positive", {"n_epochs": 0}),
        (TypeError, "Argument n_epochs should be an integer", {"n_epochs": 10.0}),
        (ValueError, "Either n_iterations or n_epochs should be defined", {}),
        (
            ValueError,
            "Either n_iterations or n_epochs should be defined",
            {"n_iterations": 1, "n_epochs": 2},
        ),
    )
    with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"):
        for expected_exc, message, kwargs in cases:
            with pytest.raises(expected_exc, match=message):
                CustomPeriodicEvent(**kwargs)
def test_integration_epochs():
    """Epoch-period events fire at the right epochs and track ``state.epochs_3``.

    With a period of 3 epochs over a 10-epoch run, the handlers below require
    STARTED only at epochs where ``(epoch - 1) % 3 == 0`` (1, 4, 7, 10) and
    COMPLETED only where ``epoch % 3 == 0`` (3, 6, 9). ``engine.state.epochs_3``
    must equal the 1-based index of the current period in both handlers.
    """

    def update(*args, **kwargs):
        pass

    engine = Engine(update)

    n_epochs = 3
    max_epochs = 10
    # Constructing CustomPeriodicEvent emits a DeprecationWarning; the other
    # tests in this module require it too, so assert it here as well.
    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
        cpe = CustomPeriodicEvent(n_epochs=n_epochs)
        cpe.attach(engine)

    data = list(range(16))
    custom_period = [1]
    # Counts STARTED calls so the assertions inside that handler cannot pass
    # vacuously if the event were never fired.
    n_calls_started = [0]

    @engine.on(cpe.Events.EPOCHS_3_STARTED)
    def on_my_epoch_started(engine):
        assert (engine.state.epoch - 1) % n_epochs == 0
        assert engine.state.epochs_3 == custom_period[0]
        n_calls_started[0] += 1

    @engine.on(cpe.Events.EPOCHS_3_COMPLETED)
    def on_my_epoch_ended(engine):
        assert engine.state.epoch % n_epochs == 0
        assert engine.state.epochs_3 == custom_period[0]
        custom_period[0] += 1

    engine.run(data, max_epochs=max_epochs)

    # COMPLETED fired once per full period (epochs 3, 6, 9): 1 + 3 == 4.
    assert custom_period[0] == 4
    # STARTED also fires for the trailing partial period (epoch 10): 4 calls.
    assert n_calls_started[0] == 4
def _test(n_iterations, max_epochs, n_iters_per_epoch):
    """Run an engine with an iteration-period CustomPeriodicEvent and check it.

    Args:
        n_iterations: period, in iterations, of the event under test.
        max_epochs: number of epochs passed to ``engine.run``.
        n_iters_per_epoch: length of the dummy dataset (iterations per epoch).

    Inside the handlers, STARTED must only fire on iterations where
    ``(iteration - 1) % n_iterations == 0`` and COMPLETED where
    ``iteration % n_iterations == 0``. ``engine.state.iterations_<n>`` must
    match the running period count. After the run, the call counts are checked
    against the total number of iterations.
    """

    def update(*args, **kwargs):
        pass

    engine = Engine(update)
    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
        cpe = CustomPeriodicEvent(n_iterations=n_iterations)
        cpe.attach(engine)

    data = list(range(n_iters_per_epoch))
    custom_period = [0]
    n_calls_iter_started = [0]
    n_calls_iter_completed = [0]

    state_attr = "iterations_{}".format(n_iterations)
    event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations))
    event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations))

    @engine.on(event_started)
    def on_my_event_started(engine):
        assert (engine.state.iteration - 1) % n_iterations == 0
        custom_period[0] += 1
        assert getattr(engine.state, state_attr) == custom_period[0]
        n_calls_iter_started[0] += 1

    @engine.on(event_completed)
    def on_my_event_ended(engine):
        assert engine.state.iteration % n_iterations == 0
        assert getattr(engine.state, state_attr) == custom_period[0]
        n_calls_iter_completed[0] += 1

    engine.run(data, max_epochs=max_epochs)

    # Exact integer arithmetic instead of float division + floor.
    total_iterations = len(data) * max_epochs
    n_full_periods, remainder = divmod(total_iterations, n_iterations)

    assert custom_period[0] == n_calls_iter_started[0]
    # STARTED fires at iterations 1, 1 + n, 1 + 2n, ..., so a trailing partial
    # period adds one extra call. The expected value is computed in full before
    # comparing, so the count is checked whether or not there is a remainder.
    expected_started = n_full_periods + (1 if remainder else 0)
    assert n_calls_iter_started[0] == expected_started
    # COMPLETED only fires for full periods.
    assert n_calls_iter_completed[0] == n_full_periods
def get_loops(self):
    """Create the loop objects for ``self.mode`` and register their handlers.

    - ``l.TRAIN``: returns ``[train, val, visualize]`` loops. The criterion chain
      and the model are attached to the train and val engines. Four
      CustomPeriodicEvents on the train engine drive the following:
        * periodic train loss logging;
        * periodic train metric logging;
        * ``val_loop.run(1)`` followed by loss/metric logging;
        * ``visualize_loop.run(1)`` followed by score/match plots.
    - ``l.TEST``: returns ``[test]``. If ``du.AACHEN`` is among the configured
      dataset names, inference output is saved after every iteration.
      Otherwise, joined logs are summarised and written to CSV after each epoch.
    - ``l.ANALYZE``: returns ``[analyze]`` with no extra handlers.
    - any other mode: returns ``None``.

    Side effect: the per-mode entries of ``self.metric_config`` are modified in
    place. ``l.MODE`` is set for TRAIN/VAL/TEST, and ``du.DATASET_NAME`` for
    TEST. These entries are presumably dicts; confirm against the config loader.
    """
    if self.mode == l.TRAIN:
        trainer = get_loop(self, self.mode)
        validator = get_loop(self, l.VAL)
        visualizer = get_loop(self, l.VISUALIZE)

        train_cfg = self.metric_config[l.TRAIN]
        train_cfg[l.MODE] = l.TRAIN
        val_cfg = self.metric_config[l.VAL]
        val_cfg[l.MODE] = l.VAL
        # The visualize config is only read below; it is not tagged with l.MODE.
        vis_cfg = self.metric_config[l.VISUALIZE]

        # Attach order: criterion chain (train, val), then model (train, val).
        self.criterion_chain.attach(trainer.engine, train_cfg)
        self.criterion_chain.attach(validator.engine, val_cfg)
        self.model.attach(trainer.engine, train_cfg)
        self.model.attach(validator.engine, val_cfg)

        # TODO. Add to the heart of the container. Calculate only during test with measure_time flag
        # AveragePeriodicMetric(TimeTransformer()).attach(engine, n.FORWARD_TIME)

        loss_every = CustomPeriodicEvent(n_iterations=train_cfg[exp.LOSS_LOG_ITER])
        metric_every = CustomPeriodicEvent(n_iterations=train_cfg[exp.METRIC_LOG_ITER])
        val_every = CustomPeriodicEvent(n_iterations=val_cfg[exp.LOG_ITER])
        vis_every = CustomPeriodicEvent(n_epochs=vis_cfg[exp.LOG_EPOCH])
        for periodic in (loss_every, metric_every, val_every, vis_every):
            periodic.attach(trainer.engine)

        # NOTE(review): relies on the private `_periodic_event_completed`
        # attribute of CustomPeriodicEvent.
        @trainer.engine.on(loss_every._periodic_event_completed)
        def _log_train_losses(engine):
            plot_losses_tensorboard(self.writer, engine, engine, l.TRAIN)

        @trainer.engine.on(metric_every._periodic_event_completed)
        def _log_train_metrics(engine):
            plot_metrics_tensorboard(self.writer, engine, engine, l.TRAIN)

        @trainer.engine.on(val_every._periodic_event_completed)
        def _run_validation(engine):
            validator.run(1)
            # Values come from the val engine; `engine` (train) is passed alongside.
            plot_losses_tensorboard(self.writer, validator.engine, engine, l.VAL)
            plot_metrics_tensorboard(self.writer, validator.engine, engine, l.VAL)

        @trainer.engine.on(vis_every._periodic_event_completed)
        def _run_visualization(engine):
            visualizer.run(1)
            plot_scores(self.writer, engine, visualizer.engine, (eu.SCORE1, eu.SCORE2))
            plot_scores(self.writer, engine, visualizer.engine, (eu.SAL_SCORE1, eu.SAL_SCORE2))
            plot_scores(self.writer, engine, visualizer.engine, (eu.CONF_SCORE1, eu.CONF_SCORE2), True)
            plot_kp_matches(self.writer, engine, visualizer.engine, vis_cfg[exp.PX_THRESH])
            plot_desc_matches(self.writer, engine, visualizer.engine, vis_cfg[exp.PX_THRESH],
                              DescriptorDistance.INV_COS_SIM)

        return [trainer, validator, visualizer]

    if self.mode == l.TEST:
        tester = get_loop(self, self.mode)

        test_cfg = self.metric_config[l.TEST]
        test_cfg[l.MODE] = l.TEST
        test_cfg[du.DATASET_NAME] = list(self.dataset_config[self.mode].keys())

        self.model.attach(tester.engine, test_cfg)

        if du.AACHEN not in test_cfg[du.DATASET_NAME]:
            @tester.engine.on(Events.EPOCH_COMPLETED)
            def _summarise_epoch(engine):
                logs = join_logs(engine)
                print_summary(logs, test_cfg)
                test_log_to_csv(self.log_dir, logs, test_cfg, self.model_config,
                                test_cfg[du.DATASET_NAME])
        else:
            @tester.engine.on(Events.ITERATION_COMPLETED)
            def _save_aachen_output(engine):
                # dataset_config is looked up when the handler runs, not at registration.
                save_aachen_inference(self.dataset_config[self.mode], engine.state.output)

        return [tester]

    if self.mode == l.ANALYZE:
        return [get_loop(self, self.mode)]

    return None